@@ -123,17 +123,16 @@ def save_decompose_config(
 ):
     """Saves the current state to a JSON file."""

-    active_models_map = {}
-    split_positions_map = {}
+    tasks_map_copy = {}
     for model_name, task_info in tasks_map.items():
-        active_models_map[model_name] = task_info["original_path"]
-        split_positions_map[model_name] = task_info["split_positions"]
+        tasks_map_copy[model_name] = {}
+        for key in ["original_path", "split_positions"]:
+            tasks_map_copy[model_name][key] = task_info[key]

     config = {
         "max_subgraph_size": max_subgraph_size,
         "incorrect_models": list(incorrect_paths),
-        "active_models_map": active_models_map,
-        "split_positions_map": split_positions_map,
+        "tasks_map": tasks_map_copy,
         "failed_decomposition_models": list(failed_decomposition_models),
     }
     config_path = get_decompose_config_path(work_dir)
@@ -283,9 +282,8 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
     print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")

     prev_config = load_decompose_config(prev_pass_dir)
-    prev_active_models_map = prev_config.get("active_models_map", {})
-    prev_split_positions_map = prev_config.get("split_positions_map", {})
     prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
+    prev_tasks_map = prev_config.get("tasks_map", {})

     # Load previous max size as fallback
     prev_max_subgraph_size = prev_config.get("max_subgraph_size")
@@ -302,12 +300,12 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
         model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
         model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
         subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
-        print(f"subgraph_path: {subgraph_path}")
-        print(f"model_name: {model_name}, subgraph_idx: {subgraph_idx}")
-        assert model_name in prev_active_models_map
+
+        assert model_name in prev_tasks_map
+        pre_task_for_model = prev_tasks_map[model_name]

         # Reconstruct previous subgraph size to locate the failing segment
-        prev_split_positions = prev_split_positions_map.get(model_name, [])
+        prev_split_positions = pre_task_for_model.get("split_positions", [])
         subgraph_size = reconstruct_subgraph_size(prev_split_positions)
         assert subgraph_idx < len(
             subgraph_size
@@ -316,7 +314,7 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
         if model_name not in tasks_map:
             tasks_map[model_name] = {
                 "subgraph_path": subgraph_path,
-                "original_path": prev_active_models_map[model_name],
+                "original_path": pre_task_for_model["original_path"],
                 "subgraph_size": subgraph_size[subgraph_idx],
                 "split_positions": set(),
             }
@@ -447,6 +445,11 @@ def main(args):
         ) = execute_decomposition_phase(
             max_subgraph_size, tasks_map, args.framework, pass_work_dir
         )
+    else:
+        config = load_decompose_config(pass_work_dir)
+        max_subgraph_size = config["max_subgraph_size"]
+        failed_decomposition = config["failed_decomposition_models"]
+        tasks_map = config.get("tasks_map", {})

     # --- Step 4: Testing ---
     pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
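
For reference, a minimal sketch of the consolidated JSON that save_decompose_config writes after this change, with the two former maps folded into a single "tasks_map" entry. The model name, paths, and numeric values below are hypothetical, and "split_positions" is shown already serialized as a list:

    {
        "max_subgraph_size": 64,
        "incorrect_models": ["/work/pass_1/resnet18_2/"],
        "tasks_map": {
            "resnet18": {
                "original_path": "/work/models/resnet18",
                "split_positions": [8, 16, 24]
            }
        },
        "failed_decomposition_models": []
    }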