@@ -111,8 +111,9 @@ def _print(self):
111111
112112@dataclass
113113class DecomposeConfig :
114+ method : str
115+ tolerance : int | List [int ]
114116 max_subgraph_size : int = - 1
115- incorrect_models : List [str ] = field (default_factory = list )
116117 tasks_map : Dict [str , Union [int , str , list , dict ]] = field (default_factory = dict )
117118 running_states : Dict [str , Union [int , str , list , dict ]] = field (default_factory = dict )
118119
@@ -139,6 +140,11 @@ def load(self, work_dir):
139140 def get_config_path (self , work_dir ) -> str :
140141 return os .path .join (work_dir , "decompose_config.json" )
141142
143+ def get_incorrect_models (self , pass_id ):
144+ pass_key = get_pass_name (pass_id )
145+ assert pass_key in self .running_states
146+ return self .running_states [pass_key ]["incorrect_models" ]
147+
142148 def update_running_states (self , pass_id , ** kwargs ):
143149 pass_key = get_pass_name (pass_id )
144150 if self .running_states .get (pass_key , None ) is None :
@@ -242,7 +248,6 @@ def run_decomposer_for_multi_models(
242248 )
243249 for model_name , task_info in tasks_map .items ():
244250 original_path = task_info ["original_path" ]
245-
246251 split_positions = sorted (list (task_info ["split_positions" ]))
247252
248253 method = "fixed-start"
@@ -312,9 +317,8 @@ def reconstruct_split_positions_for_subgraphs(
312317
313318 start_pos , end_pos = split_positions [subgraph_idx : subgraph_idx + 2 ]
314319 new_split_positions = new_split_positions + list (
315- range (start_pos , end_pos + max_subgraph_size - 1 , max_subgraph_size )
320+ range (start_pos , end_pos + max_subgraph_size , max_subgraph_size )
316321 )
317-
318322 return sorted (list (set (new_split_positions )))
319323
320324
@@ -353,25 +357,27 @@ def extract_model_name_and_subgraph_idx(subgraph_path):
353357 return model_name , subgraph_idx
354358
355359
356- def collect_incorrect_subgraph_idxs (args , model_names , incorrect_models ):
360+ def collect_incorrect_subgraph_idxs (args , target_model_names , incorrect_models ):
357361 model_name2subgraph_idxs = {}
358362 for subgraph_path in sorted (incorrect_models ):
359363 model_name , subgraph_idx = extract_model_name_and_subgraph_idx (subgraph_path )
360364 print (f"{ subgraph_path = } " )
365+ print (f"{ model_name = } , { subgraph_idx = } " )
366+ assert model_name in target_model_names , f"{ model_name = } , { subgraph_idx = } "
361367
362368 if model_name not in model_name2subgraph_idxs :
363369 model_name2subgraph_idxs [model_name ] = []
364370 model_name2subgraph_idxs [model_name ].append (subgraph_idx )
365371
366372 if args .method == "fixed-start" :
367- for model_name in model_names :
373+ print (model_name2subgraph_idxs )
374+ for model_name in target_model_names :
368375 if model_name not in model_name2subgraph_idxs :
369376 model_name2subgraph_idxs [model_name ] = [1 ]
370377 else :
371- assert (
372- len (model_name2subgraph_idxs [model_name ]) == 1
373- and model_name2subgraph_idxs [model_name ] == 0
374- )
378+ assert len (
379+ model_name2subgraph_idxs [model_name ]
380+ ) == 1 and model_name2subgraph_idxs [model_name ] == [0 ]
375381 return model_name2subgraph_idxs
376382
377383
@@ -382,18 +388,19 @@ def generate_successor_tasks(args, base_output_dir, current_pass_id):
382388
383389 prev_config = DecomposeConfig .load (prev_pass_dir )
384390 max_subgraph_size = prev_config .max_subgraph_size // 2
385- if not prev_config .incorrect_models :
391+ incorrect_models = prev_config .get_incorrect_models (current_pass_id )
392+ if args .method != "fixed-start" and not incorrect_models :
386393 return {}, max_subgraph_size , prev_config .running_states
387394
388395 tasks_map = {}
389396 prev_tasks_map = prev_config .tasks_map
390397
398+ target_model_names = list (prev_tasks_map .keys ())
391399 model_name2subgraph_idxs = collect_incorrect_subgraph_idxs (
392- args , list ( prev_tasks_map . keys ()), prev_config . incorrect_models
400+ args , target_model_names , incorrect_models
393401 )
394402
395403 for model_name , subgraph_idxs in model_name2subgraph_idxs .items ():
396- assert model_name in prev_tasks_map
397404 pre_task_for_model = prev_tasks_map [model_name ]
398405
399406 prev_split_positions = pre_task_for_model .get ("split_positions" , [])
@@ -500,8 +507,7 @@ def count_unique_original_models(incorrect_models):
500507 return len (original_model_paths )
501508
502509
503- def print_summary_and_suggestion (next_round_models , max_subgraph_size ):
504- """Print suggestion/result."""
510+ def print_summary_and_suggestion (args , next_round_models , max_subgraph_size ):
505511 print ("\n " + "=" * 80 )
506512 if next_round_models and max_subgraph_size > 1 :
507513 print (f">>> [SUGGESTION] Issues remain (Count: { len (next_round_models )} )." )
@@ -527,6 +533,8 @@ def main(args):
527533 args , current_pass_id , base_output_dir
528534 )
529535 decompose_config = DecomposeConfig (
536+ method = args .method ,
537+ tolerance = args .tolerance ,
530538 max_subgraph_size = max_subgraph_size ,
531539 tasks_map = tasks_map ,
532540 running_states = running_states ,
@@ -559,7 +567,6 @@ def main(args):
559567 run_evaluation (args .framework , args .test_config , work_dir , log_path )
560568
561569 # --- Step 4: Analysis ---
562- next_pass_incorrect_models = set ()
563570 if task_controller .task_scheduler ["post_analysis" ]:
564571 tolerance = (
565572 args .tolerance [0 ] if isinstance (args .tolerance , list ) else args .tolerance
@@ -572,15 +579,17 @@ def main(args):
572579 num_incorrect_models = num_original_models ,
573580 incorrect_models = list (next_pass_incorrect_models ),
574581 )
582+
575583 print (
576584 f"[Analysis] Found { len (next_pass_incorrect_models )} incorrect subgraphs ({ num_original_models } original models)."
577585 )
578586 for idx , model_path in enumerate (next_pass_incorrect_models ):
579587 print (f"- [{ idx } ] { model_path } " )
580- print_summary_and_suggestion (next_pass_incorrect_models , max_subgraph_size )
588+ print_summary_and_suggestion (
589+ args , next_pass_incorrect_models , max_subgraph_size
590+ )
581591
582592 # --- Step 5: Save States ---
583- decompose_config .incorrect_models = list (next_pass_incorrect_models )
584593 decompose_config .save (work_dir )
585594
586595
0 commit comments