@@ -117,12 +117,18 @@ def load_decompose_config(work_dir: str) -> Dict[str, Any]:
 def save_decompose_config(
     work_dir: str,
     max_subgraph_size: int,
+    tasks_map: Dict[str, Union[int, str, list, dict]],
     incorrect_paths: Union[List[str], Set[str]],
-    active_models_map: Dict[str, str],
-    split_positions_map: Dict[str, List[int]],
     failed_decomposition_models: Union[List[str], Set[str]],
 ):
     """Saves the current state to a JSON file."""
+
+    active_models_map = {}
+    split_positions_map = {}
+    for model_name, task_info in tasks_map.items():
+        active_models_map[model_name] = task_info["original_path"]
+        split_positions_map[model_name] = task_info["split_positions"]
+
     config = {
         "max_subgraph_size": max_subgraph_size,
         "incorrect_models": list(incorrect_paths),
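Note on this first hunk: callers no longer thread active_models_map and split_positions_map through save_decompose_config; both are now derived inside the function from tasks_map. A minimal sketch of that derivation, with the entry shape taken from the fields this diff adds and the path values invented for illustration:

    from typing import Any, Dict

    tasks_map: Dict[str, Dict[str, Any]] = {
        "resnet_0": {
            "subgraph_path": "/work/pass_0/decomposed/resnet_0",  # hypothetical
            "original_path": "/work/models/resnet",               # hypothetical
            "subgraph_size": [0, 128],
            "split_positions": {32, 64, 96},
        }
    }

    active_models_map = {m: t["original_path"] for m, t in tasks_map.items()}
    split_positions_map = {m: sorted(t["split_positions"]) for m, t in tasks_map.items()}

One caveat: json.dump cannot serialize a Python set, so if split_positions is still the set() placeholder when a pass is saved, the save path presumably normalizes it to a list first (the sorted() call does that in the sketch above).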
@@ -143,7 +149,7 @@ def get_model_name_with_subgraph_tag(model_path):
     return f"{fields[-2]}_{fields[-1]}" if re.match(pattern, fields[-1]) else fields[-1]
 
 
-def run_decomposer(
+def run_naive_decomposer(
     framework: str,
     model_path: str,
     output_dir: str,
@@ -170,8 +176,8 @@ def run_decomposer(
         json.dumps(decorator_config).encode()
     ).decode()
 
-    print(f"[Decomposing] model_path: {model_path}")
-    print(f"[Decomposing] split_positions: {split_positions}")
+    print(f"[Decomposition] model_path: {model_path}")
+    print(f"[Decomposition] split_positions: {split_positions}")
 
     cmd = [
         sys.executable,
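For context on the renamed function: decorator_config is serialized to JSON and base64-encoded just above the relabeled prints, presumably so it survives being passed through the subprocess command line. A self-contained sketch of that round trip (the child-side decode is an assumption, not shown in this diff):

    import base64
    import json

    decorator_config = {"framework": "paddle", "max_subgraph_size": 64}  # hypothetical contents

    # Parent side: JSON -> bytes -> base64 text, safe to embed in argv.
    config_b64 = base64.b64encode(json.dumps(decorator_config).encode()).decode()

    # Child side (assumed): reverse the transformation.
    assert json.loads(base64.b64decode(config_b64).decode()) == decorator_config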
@@ -185,13 +191,13 @@ def run_decomposer(
     result = subprocess.run(
         cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
     )
+    # print(result.stdout)
     if result.returncode != 0:
         print(
             f"[ERROR] Decomposition failed for {model_path}\n{result.stderr}",
             flush=True,
         )
         return False
-    # print(result.stdout)
     return True
 
 
@@ -215,8 +221,8 @@ def run_evaluation(
         for item in (f"--{key}", str(value))
     ]
 
-    print(f"[Batch Testing] Logging to: {log_path}")
-    print(f"[Command] {' '.join(cmd)}")
+    print(f"[Evaluation] Logging to: {log_path}")
+    print(f"[Evaluation] command: {' '.join(cmd)}")
 
     os.makedirs(os.path.dirname(log_path), exist_ok=True)
     with open(log_path, "w") as f:
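The cmd being logged here is built by flattening a key/value test config into --key value argument pairs. A minimal sketch of that idiom with invented values (the script name is a placeholder):

    import sys

    test_config = {"tolerance": 1e-5, "batch_size": 8}  # hypothetical values

    cmd = [sys.executable, "evaluate.py"] + [
        item
        for key, value in test_config.items()
        for item in (f"--{key}", str(value))
    ]
    # cmd[2:] == ['--tolerance', '1e-05', '--batch_size', '8']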
@@ -257,19 +263,18 @@ def generate_initial_tasks(args):
     initial_failures = get_incorrect_models(args.tolerance, args.log_file)
 
     tasks_map = {}
-    active_models_map_for_save = {}
 
     for model_path in initial_failures:
-        model_name = os.path.basename(model_path.rstrip("/"))
-        active_models_map_for_save[model_name] = model_path
+        model_name = get_model_name_with_subgraph_tag(model_path)
         tasks_map[model_name] = {
             "subgraph_path": model_path,
             "original_path": model_path,
             "subgraph_size": [0, kMaxGraphSize],
+            "split_positions": set(),
         }
 
     max_subgraph_size = args.max_subgraph_size
-    return tasks_map, active_models_map_for_save, max_subgraph_size
+    return tasks_map, max_subgraph_size
 
 
 def generate_refined_tasks(base_output_dir, current_pass_id):
@@ -286,22 +291,20 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
     prev_max_subgraph_size = prev_config.get("max_subgraph_size")
     max_subgraph_size = prev_max_subgraph_size // 2
 
-    if not prev_incorrect_subgraphs or prev_max_subgraph_size <= 1:
-        return {}, {}, max_subgraph_size
+    if not prev_incorrect_subgraphs:
+        return {}, max_subgraph_size
 
     print("[Analysis] Refining splits based on previous incorrect models ...")
 
     tasks_map = {}
-    active_models_map_for_save = {}
-
     for subgraph_path in prev_incorrect_subgraphs:
         # Parse model name and subgraph index
         model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
         model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
         subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
-
+        print(f"subgraph_path: {subgraph_path}")
+        print(f"model_name: {model_name}, subgraph_idx: {subgraph_idx}")
         assert model_name in prev_active_models_map
-        active_models_map_for_save[model_name] = prev_active_models_map[model_name]
 
         # Reconstruct previous subgraph size to locate the failing segment
         prev_split_positions = prev_split_positions_map.get(model_name, [])
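Naming convention assumed by this parsing: each failing subgraph directory is named <model_name>_<idx>, where model_name itself may contain underscores, so the numeric index must be split off from the right. A quick sketch with a hypothetical path:

    import os

    subgraph_path = "/work/pass_0/decomposed/mobile_net_v2_3/"  # hypothetical

    name_with_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
    model_name = "_".join(name_with_idx.split("_")[:-1])  # 'mobile_net_v2'
    subgraph_idx = int(name_with_idx.split("_")[-1])      # 3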
@@ -315,15 +318,15 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
             "subgraph_path": subgraph_path,
             "original_path": prev_active_models_map[model_name],
             "subgraph_size": subgraph_size[subgraph_idx],
+            "split_positions": set(),
         }
 
-    return tasks_map, active_models_map_for_save, max_subgraph_size
+    return tasks_map, max_subgraph_size
 
 
 def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, pass_work_dir):
     """Executes the decomposition phase (Phase 1)."""
     failed_decomposition = []
-    final_used_splits_map = {}
 
     need_decompose = True if len(tasks_map) > 0 else False
     if need_decompose:
@@ -338,27 +341,27 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, pass_wo
         print(f"decomposed_samples_dir: {decomposed_samples_dir}")
 
         for model_name, task_info in tasks_map.items():
-            print(f"[Decomposing] max_subgraph_size: {max_subgraph_size}")
+            print(f"[Decomposition] max_subgraph_size: {max_subgraph_size}")
             original_path = task_info["original_path"]
             split_positions = calculate_split_positions_for_subgraph(
                 task_info["subgraph_size"], max_subgraph_size
             )
-            final_used_splits_map[model_name] = split_positions
+            task_info["split_positions"] = split_positions
 
             rectified_model_path = get_rectfied_model_path(original_path)
             assert os.path.exists(
                 rectified_model_path
             ), f"{rectified_model_path} does not exist."
 
-            success = run_decomposer(
+            success = run_naive_decomposer(
                 framework, rectified_model_path, decomposed_samples_dir, split_positions
             )
             if not success:
                 failed_decomposition.append(rectified_model_path)
 
         num_decomposed_samples = count_samples(decomposed_samples_dir)
         print(
-            f"[Decomposing] number of graphs: {len(tasks_map)} -> {num_decomposed_samples}",
+            f"[Decomposition] number of graphs: {len(tasks_map)} -> {num_decomposed_samples}",
             flush=True,
         )
         if (
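calculate_split_positions_for_subgraph itself is not part of this diff; the loop implies it maps the failing subgraph's [start, end) op range plus the halved max_subgraph_size to absolute split positions inside that range, which each pass now stores on task_info["split_positions"]. A plausible reimplementation under that assumption:

    from typing import List

    def calculate_split_positions_for_subgraph(
        subgraph_size: List[int], max_subgraph_size: int
    ) -> List[int]:
        # Hypothetical: cut [start, end) into chunks of at most
        # max_subgraph_size ops and return the interior cut points.
        start, end = subgraph_size
        return list(range(start + max_subgraph_size, end, max_subgraph_size))

    # A subgraph spanning ops [64, 128) with max size 16 -> [80, 96, 112]
    print(calculate_split_positions_for_subgraph([64, 128], 16))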
@@ -381,7 +384,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, pass_wo
     if failed_decomposition:
         print(f"[WARN] {len(failed_decomposition)} models failed to decompose.")
 
-    return failed_decomposition, final_used_splits_map, max_subgraph_size
+    return tasks_map, failed_decomposition, max_subgraph_size
 
 
 def print_final_summary(next_round_models, max_subgraph_size):
@@ -408,17 +411,11 @@ def main(args):
 
     # --- Step 1: Prepare Tasks ---
     if current_pass_id == 0:
-        (
-            tasks_map,
-            active_models_map_for_save,
-            max_subgraph_size,
-        ) = generate_initial_tasks(args)
+        tasks_map, max_subgraph_size = generate_initial_tasks(args)
     else:
-        (
-            tasks_map,
-            active_models_map_for_save,
-            max_subgraph_size,
-        ) = generate_refined_tasks(base_output_dir, current_pass_id)
+        tasks_map, max_subgraph_size = generate_refined_tasks(
+            base_output_dir, current_pass_id
+        )
 
     print(f"[INFO] initial max_subgraph_size: {max_subgraph_size}")
     print(f"[INFO] number of incorrect models: {len(tasks_map)}")
@@ -442,11 +439,10 @@ def main(args):
 
     # --- Step 3: Decomposition ---
     failed_decomposition = []
-    final_used_splits_map = {}
     if task_controller.task_scheduler["run_decomposer"]:
         (
+            tasks_map,
             failed_decomposition,
-            final_used_splits_map,
             max_subgraph_size,
         ) = execute_decomposition_phase(
             max_subgraph_size, tasks_map, args.framework, pass_work_dir
@@ -455,23 +451,22 @@ def main(args):
     # --- Step 4: Testing ---
     pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
     if task_controller.task_scheduler["run_evaluation"]:
-        print("\n--- Phase 2: Batch Testing ---")
+        print("\n--- Phase 2: Evaluation ---")
         run_evaluation(args.framework, args.test_config, pass_work_dir, pass_log_path)
 
     # --- Step 5: Analysis ---
     next_round_models = set()
     if task_controller.task_scheduler["post_analysis"]:
         print("\n--- Phase 3: Analysis ---")
         next_round_models = get_incorrect_models(args.tolerance, pass_log_path)
-        print(f"[Result] Found {len(next_round_models)} incorrect subgraphs.")
+        print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
 
     # --- Step 6: Save State ---
     save_decompose_config(
         pass_work_dir,
         max_subgraph_size,
+        tasks_map,
         next_round_models,
-        active_models_map_for_save,
-        final_used_splits_map,
         failed_decomposition,
     )
 
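Across the whole patch, the per-pass state that save_decompose_config writes and generate_refined_tasks reads back now flows through tasks_map alone. Only the "max_subgraph_size" and "incorrect_models" keys are visible in this diff; a hypothetical example of one pass's saved config, with every other key name and value assumed from the surrounding reads:

    config = {
        "max_subgraph_size": 32,
        "incorrect_models": ["/work/pass_1/decomposed/mobile_net_v2_3"],
        "active_models_map": {"mobile_net_v2": "/work/models/mobile_net_v2"},
        "split_positions_map": {"mobile_net_v2": [80, 96, 112]},
        "failed_decomposition_models": [],
    }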