 from graph_net.analysis_util import get_incorrect_models
 from graph_net import path_utils

-kMaxGraphSize = 4096
+MAX_GRAPH_SIZE = 4096


 def convert_b64_string_to_json(b64str):
@@ -109,9 +109,30 @@ def _print(self):
         print()


+@dataclass
+class ModelRecord:
+    original_path: str
+    uniform_split_positions: List[int] = field(default_factory=list)
+    subgraph_paths: List[str] = field(default_factory=list)
+    incorrect_subgraph_idxs: List[int] = field(default_factory=list)
+
+    def get_split_positions(self, decompose_method):
+        if decompose_method == "fixed-start":
+            assert (
+                len(self.uniform_split_positions) >= 2
+            ), f"{self.uniform_split_positions=}"
+            return [0, self.uniform_split_positions[1]]
+        return self.uniform_split_positions
+
+    def update_for_next_decompose(self, subgraph_idx, max_subgraph_size):
+        self.uniform_split_positions = reconstruct_split_positions_for_subgraphs(
+            self.uniform_split_positions, subgraph_idx, max_subgraph_size
+        )
+
+
 @dataclass
 class DecomposeConfig:
-    method: str
+    decompose_method: str
     tolerance: int | List[int]
     max_subgraph_size: int = -1
     tasks_map: Dict[str, Union[int, str, list, dict]] = field(default_factory=dict)
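The new `ModelRecord` centralizes the per-model bookkeeping that was previously spread across `tasks_map` dicts. A condensed, runnable sketch of its split-position behavior (the values are illustrative):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class ModelRecord:  # condensed copy of the class added above
    original_path: str
    uniform_split_positions: List[int] = field(default_factory=list)

    def get_split_positions(self, decompose_method):
        # "fixed-start" probes only the first subgraph: [start, first split].
        if decompose_method == "fixed-start":
            assert len(self.uniform_split_positions) >= 2
            return [0, self.uniform_split_positions[1]]
        return self.uniform_split_positions

rec = ModelRecord("model_a", uniform_split_positions=[0, 1024, 2048, 4096])
print(rec.get_split_positions("fixed-start"))  # [0, 1024]
print(rec.get_split_positions("uniform"))      # [0, 1024, 2048, 4096]
```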
@@ -145,18 +166,28 @@ def get_incorrect_models(self, pass_id):
         assert pass_key in self.running_states
         return self.running_states[pass_key]["incorrect_models"]

-    def update_running_states(self, pass_id, **kwargs):
-        pass_key = get_pass_name(pass_id)
-        if self.running_states.get(pass_key, None) is None:
+    def update_running_states(self, pass_id, incorrect_models, model_name2record):
+        assert pass_id == "initial" or isinstance(pass_id, int)
+        pass_key = get_pass_name(pass_id) if isinstance(pass_id, int) else pass_id
+        if pass_key not in self.running_states:
             self.running_states[pass_key] = {}

-        for key, value in kwargs.items():
-            assert key in [
-                "num_incorrect_models",
-                "incorrect_models",
-                "failed_decomposition_models",
-            ]
-            self.running_states[pass_key][key] = value
+        self.running_states[pass_key]["incorrect_models_from_log"] = list(
+            sorted(incorrect_models)
+        )
+        if model_name2record:
+            target_model_names = list(model_name2record.keys())
+            model_name2subgraph_idxs = collect_incorrect_subgraph_idxs(
+                self.decompose_method,
+                target_model_names,
+                incorrect_models,
+                model_name2record,
+            )
+            for model_name, model_record in sorted(model_name2record.items()):
+                model_record.incorrect_subgraph_idxs = model_name2subgraph_idxs[
+                    model_name
+                ]
+                self.running_states[pass_key][model_name] = model_record.__dict__


 def get_rectfied_model_path(model_path):
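After this change, each pass entry holds the raw log result under `incorrect_models_from_log` plus one serialized `ModelRecord` per model (via `__dict__`). A hypothetical snapshot, assuming `get_pass_name(1)` renders `"pass_1"` and using made-up paths:

```python
# Hypothetical shape of running_states after update_running_states;
# the pass key and all paths here are assumptions for illustration.
running_states = {
    "pass_1": {
        "incorrect_models_from_log": ["samples/model_a/subgraph_1"],
        "model_a": {
            "original_path": "samples/model_a",
            "uniform_split_positions": [0, 2048, 4096],
            "subgraph_paths": [
                "samples/model_a/subgraph_0",
                "samples/model_a/subgraph_1",
            ],
            "incorrect_subgraph_idxs": [1],
        },
    },
}
```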
@@ -226,21 +257,18 @@ def run_decomposer_for_single_model(


 def run_decomposer_for_multi_models(
-    framework, tasks_map, decomposed_samples_dir, max_subgraph_size, log_path
+    framework, model_name2record, decomposed_samples_dir, max_subgraph_size, log_path
 ):
-    failed_decomposition = []
+    failed_decomposition_models = []

     print(
         f"[Decomposition] max_subgraph_size: {max_subgraph_size}, log_path: {log_path}"
     )
-    for model_name, task_info in tasks_map.items():
-        original_path = task_info["original_path"]
-        split_positions = sorted(list(task_info["split_positions"]))
-
-        method = "fixed-start"
-        if method == "fixed-start":
-            assert len(split_positions) >= 3, f"{split_positions=}"
-            split_positions = [0, split_positions[1]]
+    for model_name, model_record in model_name2record.items():
+        original_path = model_record.original_path
+        split_positions = model_record.get_split_positions(
+            decompose_method="fixed-start"
+        )

         rectified_model_path = get_rectfied_model_path(original_path)
         assert os.path.exists(
@@ -255,8 +283,8 @@ def run_decomposer_for_multi_models(
             log_path,
         )
         if not success:
-            failed_decomposition.append(rectified_model_path)
-    return tasks_map, failed_decomposition
+            failed_decomposition_models.append(rectified_model_path)
+    return failed_decomposition_models


 def run_evaluation(
@@ -314,10 +342,13 @@ def generate_initial_tasks(args):
     initial_failures = get_ranged_incorrect_models(args.tolerance, args.log_file)

     tasks_map = {}
-    max_subgraph_size = min(args.max_subgraph_size, kMaxGraphSize // 2)
+    if args.decompose_method == "fixed-start":
+        max_subgraph_size = MAX_GRAPH_SIZE
+    else:
+        max_subgraph_size = min(args.max_subgraph_size, MAX_GRAPH_SIZE)

     initial_split_positions = reconstruct_split_positions_for_subgraphs(
-        [0, kMaxGraphSize], 0, max_subgraph_size
+        [0, MAX_GRAPH_SIZE], 0, max_subgraph_size
     )
     for model_path in initial_failures:
         model_name = get_model_name_with_subgraph_tag(model_path)
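The sizing rule now depends on the method: `fixed-start` starts from the full graph (which presumably leaves the initial plan at `[0, MAX_GRAPH_SIZE]` split once), while other methods cap the user's request at `MAX_GRAPH_SIZE` rather than the old `kMaxGraphSize // 2`. A small sketch mirroring just that branch:

```python
MAX_GRAPH_SIZE = 4096

def initial_max_subgraph_size(decompose_method, requested):
    # Mirrors the branch added to generate_initial_tasks.
    if decompose_method == "fixed-start":
        return MAX_GRAPH_SIZE
    return min(requested, MAX_GRAPH_SIZE)

print(initial_max_subgraph_size("fixed-start", 1024))  # 4096
print(initial_max_subgraph_size("uniform", 1024))      # 1024
```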
@@ -327,7 +358,7 @@ def generate_initial_tasks(args):
         }

     running_states = {
-        "pass_0": {
+        "initial": {
             "num_incorrect_models": len(initial_failures),
             "incorrect_models": list(sorted(initial_failures)),
         }
@@ -343,7 +374,9 @@ def extract_model_name_and_subgraph_idx(subgraph_path):
     return model_name, subgraph_idx


-def collect_incorrect_subgraph_idxs(args, target_model_names, incorrect_models):
+def collect_incorrect_subgraph_idxs(
+    decompose_method, target_model_names, incorrect_models, model_name2record
+):
     model_name2subgraph_idxs = {}
     for subgraph_path in sorted(incorrect_models):
         model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
@@ -355,11 +388,17 @@ def collect_incorrect_subgraph_idxs(args, target_model_names, incorrect_models):
             model_name2subgraph_idxs[model_name] = []
         model_name2subgraph_idxs[model_name].append(subgraph_idx)

-    if args.method == "fixed-start":
+    if decompose_method == "fixed-start":
         print(model_name2subgraph_idxs)
         for model_name in target_model_names:
             if model_name not in model_name2subgraph_idxs:
-                model_name2subgraph_idxs[model_name] = [1]
+                if (
+                    model_name2record
+                    and len(model_name2record[model_name].uniform_split_positions) > 2
+                ):
+                    model_name2subgraph_idxs[model_name] = [1]
+                else:
+                    model_name2subgraph_idxs[model_name] = []
     else:
         assert len(
             model_name2subgraph_idxs[model_name]
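Under `fixed-start` only the leading subgraph is materialized and tested, so a model missing from the incorrect list presumably failed in the untested remainder (subgraph 1), unless its split plan is already down to a single piece. A sketch of that fallback in isolation:

```python
def fixed_start_fallback_idxs(uniform_split_positions):
    # More than two positions means a remainder still exists to blame.
    if len(uniform_split_positions) > 2:
        return [1]
    return []

print(fixed_start_fallback_idxs([0, 1024, 4096]))  # [1]
print(fixed_start_fallback_idxs([0, 4096]))        # []
```

When no record map is passed (generate_successor_tasks calls this with `None`), the fallback stays empty.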
@@ -375,15 +414,15 @@ def generate_successor_tasks(args, base_output_dir, current_pass_id):
     prev_config = DecomposeConfig.load(prev_pass_dir)
     max_subgraph_size = prev_config.max_subgraph_size // 2
     incorrect_models = prev_config.get_incorrect_models(current_pass_id)
-    if args.method != "fixed-start" and not incorrect_models:
+    if args.decompose_method != "fixed-start" and not incorrect_models:
         return {}, max_subgraph_size, prev_config.running_states

     tasks_map = {}
     prev_tasks_map = prev_config.tasks_map

     target_model_names = list(prev_tasks_map.keys())
     model_name2subgraph_idxs = collect_incorrect_subgraph_idxs(
-        args, target_model_names, incorrect_models
+        args.decompose_method, target_model_names, incorrect_models, None
     )

     for model_name, subgraph_idxs in model_name2subgraph_idxs.items():
@@ -393,6 +432,8 @@ def generate_successor_tasks(args, base_output_dir, current_pass_id):
         split_positions = reconstruct_split_positions_for_subgraphs(
             prev_split_positions, subgraph_idxs, max_subgraph_size
         )
+        if args.decompose_method == "fixed-start" and len(split_positions) > 3:
+            split_positions = split_positions[0:3]

         tasks_map[model_name] = {
             "original_path": pre_task_for_model["original_path"],
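For `fixed-start`, the successor task keeps at most three split positions, i.e. only the first two subgraphs of the reconstructed plan carry over to the next pass:

```python
# Illustrative values; mirrors the truncation added above.
split_positions = [0, 512, 1024, 2048, 4096]
if len(split_positions) > 3:
    split_positions = split_positions[0:3]
print(split_positions)  # [0, 512, 1024]
```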
@@ -430,58 +471,76 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
     return tasks_map, max_subgraph_size, running_states


-def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspace):
+def collect_decomposed_subgraphs(model_name2record, decomposed_samples_dir):
+    for root, dirs, files in os.walk(decomposed_samples_dir):
+        if path_utils.is_single_model_dir(root):
+            model_name, _ = extract_model_name_and_subgraph_idx(root)
+            assert model_name in model_name2record
+            model_record = model_name2record[model_name]
+            model_record.subgraph_paths.append(root)
+    return model_name2record
+
+
+def execute_decomposition_phase(
+    max_subgraph_size, model_name2record, framework, workspace
+):
     """Executes the decomposition phase."""

-    failed_decomposition = []
-    need_decompose = True if len(tasks_map) > 0 else False
-    method = "fixed-start"
+    failed_decomposition_models = []
+    need_decompose = True if len(model_name2record) > 0 else False
+    decompose_method = "fixed-start"
+    decomposed_samples_dir = os.path.join(
+        workspace, "samples" if framework == "torch" else "paddle_samples"
+    )

     while need_decompose:
-        decomposed_samples_dir = os.path.join(
-            workspace, "samples" if framework == "torch" else "paddle_samples"
-        )
         if not os.path.exists(decomposed_samples_dir):
             os.makedirs(decomposed_samples_dir, exist_ok=True)
         print(f"[Decomposition] decomposed_samples_dir: {decomposed_samples_dir}")

         log_path = os.path.join(
             workspace, f"log_decompose-max_subgraph_size_{max_subgraph_size}.txt"
         )
-        tasks_map, failed_decomposition = run_decomposer_for_multi_models(
-            framework, tasks_map, decomposed_samples_dir, max_subgraph_size, log_path
+        failed_decomposition_models = run_decomposer_for_multi_models(
+            framework,
+            model_name2record,
+            decomposed_samples_dir,
+            max_subgraph_size,
+            log_path,
         )
         num_decomposed_samples = count_samples(decomposed_samples_dir)
         print(
-            f"[Decomposition] number of graphs: {len(tasks_map)} -> {num_decomposed_samples}",
+            f"[Decomposition] number of graphs: {len(model_name2record)} -> {num_decomposed_samples}",
             flush=True,
         )
         if (
-            not failed_decomposition
-            and num_decomposed_samples == len(tasks_map)
+            not failed_decomposition_models
+            and num_decomposed_samples == len(model_name2record)
             and max_subgraph_size > 1
-            and method != "fixed-start"
+            and decompose_method != "fixed-start"
         ):
             need_decompose = True
             shutil.rmtree(decomposed_samples_dir)
             os.makedirs(decomposed_samples_dir, exist_ok=True)
             max_subgraph_size = max(1, max_subgraph_size // 2)
-            for model_name, task_info in tasks_map.items():
-                split_positions = task_info["split_positions"]
-                if not split_positions or len(split_positions) < 2:
+            for model_name, model_record in model_name2record.items():
+                if (
+                    not model_record.uniform_split_positions
+                    or len(model_record.uniform_split_positions) < 2
+                ):
                     continue
-                new_split_positions = reconstruct_split_positions_for_subgraphs(
-                    split_positions, 0, max_subgraph_size
-                )
-                task_info["split_positions"] = new_split_positions
+                model_record.update_for_next_decompose(0, max_subgraph_size)
         else:
             need_decompose = False
     print()

-    if failed_decomposition:
-        print(f"[WARN] {len(failed_decomposition)} models failed to decompose.")
+    if failed_decomposition_models:
+        print(f"[WARN] {len(failed_decomposition_models)} models failed to decompose.")

-    return tasks_map, failed_decomposition, max_subgraph_size
+    model_name2record = collect_decomposed_subgraphs(
+        model_name2record, decomposed_samples_dir
+    )
+    return model_name2record, max_subgraph_size


 def count_unique_original_models(incorrect_models):
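`collect_decomposed_subgraphs` replaces the old habit of returning `tasks_map`: after decomposition, every directory that `path_utils.is_single_model_dir` accepts is attached to its owning record. A minimal sketch with both helpers stubbed, since their real definitions are not part of this patch; the `subgraph_<i>` directory layout is an assumption:

```python
import os

def is_single_model_dir(path):
    # Hypothetical stand-in for path_utils.is_single_model_dir.
    return os.path.basename(path).startswith("subgraph_")

def collect(subgraph_paths_by_model, samples_dir):
    # Walk the output tree and group subgraph dirs by model name.
    for root, dirs, files in os.walk(samples_dir):
        if is_single_model_dir(root):
            model = os.path.basename(os.path.dirname(root))
            subgraph_paths_by_model.setdefault(model, []).append(root)
    return subgraph_paths_by_model
```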
@@ -518,8 +577,16 @@ def main(args):
     tasks_map, max_subgraph_size, running_states = prepare_tasks_and_verify(
         args, current_pass_id, base_output_dir
     )
+
+    model_name2record = {}
+    for model_name in tasks_map.keys():
+        model_name2record[model_name] = ModelRecord(
+            original_path=tasks_map[model_name]["original_path"],
+            uniform_split_positions=tasks_map[model_name]["split_positions"],
+        )
+
     decompose_config = DecomposeConfig(
-        method=args.method,
+        decompose_method=args.decompose_method,
         tolerance=args.tolerance,
         max_subgraph_size=max_subgraph_size,
         tasks_map=tasks_map,
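`main` now bridges the legacy `tasks_map` into records before the phases run; an equivalent, slightly more compact construction (a sketch assuming the same dict layout):

```python
model_name2record = {
    name: ModelRecord(
        original_path=info["original_path"],
        uniform_split_positions=info["split_positions"],
    )
    for name, info in tasks_map.items()
}
```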
@@ -533,14 +600,10 @@ def main(args):
     if task_controller.task_scheduler["run_decomposer"]:
         print("\n--- Phase 1: Decomposition ---", flush=True)
         (
-            tasks_map,
-            failed_decomposition,
+            model_name2record,
             max_subgraph_size,
         ) = execute_decomposition_phase(
-            max_subgraph_size, tasks_map, args.framework, work_dir
-        )
-        decompose_config.update_running_states(
-            current_pass_id, failed_decomposition_models=list(failed_decomposition)
+            max_subgraph_size, model_name2record, args.framework, work_dir
         )
     else:
         print("\n--- Phase 1: Decomposition (skipped) ---", flush=True)
@@ -560,22 +623,26 @@ def main(args):
     print(f"\n--- Phase 3: Analysis (tolerance={tolerance}) ---")
     next_pass_incorrect_models = sorted(get_incorrect_models(tolerance, log_path))
     num_original_models = count_unique_original_models(next_pass_incorrect_models)
+
     decompose_config.update_running_states(
-        current_pass_id + 1,
-        num_incorrect_models=num_original_models,
-        incorrect_models=list(next_pass_incorrect_models),
+        current_pass_id,
+        next_pass_incorrect_models,
+        model_name2record,
     )

     print(
         f"[Analysis] Found {len(next_pass_incorrect_models)} incorrect subgraphs ({num_original_models} original models)."
     )
     for idx, model_path in enumerate(next_pass_incorrect_models):
         print(f"- [{idx}] {model_path}")
+
     print_summary_and_suggestion(
         args, next_pass_incorrect_models, max_subgraph_size
     )

     # --- Step 5: Save States ---
+    for model_name, model_record in model_name2record.items():
+        print(f"- {model_name}: {model_record}")
     decompose_config.save(work_dir)


@@ -587,7 +654,7 @@ def main(args):
     parser.add_argument(
         "--test-config", type=str, required=True, help="Base64 encoded test config"
     )
-    parser.add_argument("--method", type=str, required=True)
+    parser.add_argument("--decompose-method", type=str, required=True)
     parser.add_argument(
         "--tolerance",
         type=int,
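Note that argparse converts the hyphen in `--decompose-method` to an underscore, so the `args.decompose_method` references introduced throughout this patch resolve against the renamed flag:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--decompose-method", type=str, required=True)
args = parser.parse_args(["--decompose-method", "fixed-start"])
print(args.decompose_method)  # fixed-start
```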