@@ -123,17 +123,16 @@ def save_decompose_config(
 ):
     """Saves the current state to a JSON file."""

-    active_models_map = {}
-    split_positions_map = {}
+    tasks_map_copy = {}
     for model_name, task_info in tasks_map.items():
-        active_models_map[model_name] = task_info["original_path"]
-        split_positions_map[model_name] = task_info["split_positions"]
+        tasks_map_copy[model_name] = {}
+        for key in ["original_path", "split_positions"]:
+            tasks_map_copy[model_name][key] = task_info[key]

     config = {
         "max_subgraph_size": max_subgraph_size,
         "incorrect_models": list(incorrect_paths),
-        "active_models_map": active_models_map,
-        "split_positions_map": split_positions_map,
+        "tasks_map": tasks_map_copy,
         "failed_decomposition_models": list(failed_decomposition_models),
     }
     config_path = get_decompose_config_path(work_dir)
@@ -283,9 +282,8 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
     print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")

     prev_config = load_decompose_config(prev_pass_dir)
-    prev_active_models_map = prev_config.get("active_models_map", {})
-    prev_split_positions_map = prev_config.get("split_positions_map", {})
     prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
+    prev_tasks_map = prev_config.get("tasks_map", {})

     # Load previous max size as fallback
     prev_max_subgraph_size = prev_config.get("max_subgraph_size")
@@ -294,20 +292,18 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
     if not prev_incorrect_subgraphs:
         return {}, max_subgraph_size

-    print("[Analysis] Refining splits based on previous incorrect models ...")
-
     tasks_map = {}
     for subgraph_path in prev_incorrect_subgraphs:
         # Parse model name and subgraph index
         model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
         model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
         subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
-        print(f"subgraph_path: {subgraph_path}")
-        print(f"model_name: {model_name}, subgraph_idx: {subgraph_idx}")
-        assert model_name in prev_active_models_map
+
+        assert model_name in prev_tasks_map
+        pre_task_for_model = prev_tasks_map[model_name]

         # Reconstruct previous subgraph size to locate the failing segment
-        prev_split_positions = prev_split_positions_map.get(model_name, [])
+        prev_split_positions = pre_task_for_model.get("split_positions", [])
         subgraph_size = reconstruct_subgraph_size(prev_split_positions)
         assert subgraph_idx < len(
             subgraph_size
@@ -316,7 +312,7 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
         if model_name not in tasks_map:
             tasks_map[model_name] = {
                 "subgraph_path": subgraph_path,
-                "original_path": prev_active_models_map[model_name],
+                "original_path": pre_task_for_model["original_path"],
                 "subgraph_size": subgraph_size[subgraph_idx],
                 "split_positions": set(),
             }
@@ -338,7 +334,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, pass_wo
     )
     if not os.path.exists(decomposed_samples_dir):
         os.makedirs(decomposed_samples_dir, exist_ok=True)
-    print(f"decomposed_samples_dir: {decomposed_samples_dir}")
+    print(f"- decomposed_samples_dir: {decomposed_samples_dir}")

     for model_name, task_info in tasks_map.items():
         print(f"[Decomposition] max_subgraph_size: {max_subgraph_size}")
@@ -387,8 +383,8 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, pass_wo
     return tasks_map, failed_decomposition, max_subgraph_size


-def print_final_summary(next_round_models, max_subgraph_size):
-    """Prints the final suggestion/result."""
+def print_summary_and_suggestion(next_round_models, max_subgraph_size):
+    """Print suggestion/result."""
     print("\n" + "=" * 80)
     if next_round_models and max_subgraph_size > 1:
         print(f">>> [SUGGESTION] Issues remain (Count: {len(next_round_models)}).")
@@ -447,6 +443,11 @@ def main(args):
         ) = execute_decomposition_phase(
             max_subgraph_size, tasks_map, args.framework, pass_work_dir
         )
+    else:
+        config = load_decompose_config(pass_work_dir)
+        max_subgraph_size = config["max_subgraph_size"]
+        failed_decomposition = config["failed_decomposition_models"]
+        tasks_map = config.get("tasks_map", {})

     # --- Step 4: Testing ---
     pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
@@ -470,7 +471,7 @@ def main(args):
         failed_decomposition,
     )

-    print_final_summary(next_round_models, max_subgraph_size)
+    print_summary_and_suggestion(next_round_models, max_subgraph_size)


 if __name__ == "__main__":
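
For reference, a minimal sketch of the consolidated config shape this change produces: one "tasks_map" entry per model, replacing the parallel "active_models_map" / "split_positions_map" dicts. Model names, paths, and positions below are made-up examples, the file name is illustrative, and "split_positions" is shown as a plain list on the assumption that the in-memory set is converted before JSON serialization.

import json

# Illustrative contents of the per-pass decompose config after this change.
# Keys mirror the diff; the concrete values are invented for the example.
config = {
    "max_subgraph_size": 8,
    "incorrect_models": ["out/pass_1/resnet50_3/"],
    "tasks_map": {
        "resnet50": {
            "original_path": "models/resnet50.onnx",
            "split_positions": [4, 12, 20],
        },
    },
    "failed_decomposition_models": [],
}

with open("decompose_config.json", "w") as f:
    json.dump(config, f, indent=2)

# A resumed pass can then read everything it needs for a model from one place:
with open("decompose_config.json") as f:
    loaded = json.load(f)
for model_name, task_info in loaded.get("tasks_map", {}).items():
    print(model_name, task_info["original_path"], task_info["split_positions"])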