@@ -123,17 +123,16 @@ def save_decompose_config(
 ):
     """Saves the current state to a JSON file."""

-    active_models_map = {}
-    split_positions_map = {}
+    tasks_map_copy = {}
     for model_name, task_info in tasks_map.items():
-        active_models_map[model_name] = task_info["original_path"]
-        split_positions_map[model_name] = task_info["split_positions"]
+        tasks_map_copy[model_name] = {}
+        for key in ["original_path", "split_positions"]:
+            tasks_map_copy[model_name][key] = task_info[key]

     config = {
         "max_subgraph_size": max_subgraph_size,
         "incorrect_models": list(incorrect_paths),
-        "active_models_map": active_models_map,
-        "split_positions_map": split_positions_map,
+        "tasks_map": tasks_map_copy,
         "failed_decomposition_models": list(failed_decomposition_models),
     }
     config_path = get_decompose_config_path(work_dir)
@@ -149,7 +148,7 @@ def get_model_name_with_subgraph_tag(model_path):
     return f"{fields[-2]}_{fields[-1]}" if re.match(pattern, fields[-1]) else fields[-1]


-def run_naive_decomposer(
+def run_decomposer_for_single_model(
     framework: str,
     model_path: str,
     output_dir: str,
@@ -201,6 +200,32 @@ def run_naive_decomposer(
     return True


+def run_decomposer_for_multi_models(
+    framework, tasks_map, decomposed_samples_dir, max_subgraph_size
+):
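+    """Splits each model in tasks_map and records models whose decomposition fails."""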
+    failed_decomposition = []
+
+    for model_name, task_info in tasks_map.items():
+        print(f"[Decomposition] max_subgraph_size: {max_subgraph_size}")
+        original_path = task_info["original_path"]
+        split_positions = calculate_split_positions_for_subgraph(
+            task_info["subgraph_size"], max_subgraph_size
+        )
+        task_info["split_positions"] = split_positions
+
+        rectified_model_path = get_rectfied_model_path(original_path)
+        assert os.path.exists(
+            rectified_model_path
+        ), f"{rectified_model_path} does not exist."
+
+        success = run_decomposer_for_single_model(
+            framework, rectified_model_path, decomposed_samples_dir, split_positions
+        )
+        if not success:
+            failed_decomposition.append(rectified_model_path)
+    return tasks_map, failed_decomposition
+
+
 def run_evaluation(
     framework: str, test_cmd_b64: str, work_dir: str, log_path: str
 ) -> int:
@@ -283,9 +308,8 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
     print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")

     prev_config = load_decompose_config(prev_pass_dir)
-    prev_active_models_map = prev_config.get("active_models_map", {})
-    prev_split_positions_map = prev_config.get("split_positions_map", {})
     prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
+    prev_tasks_map = prev_config.get("tasks_map", {})

     # Load previous max size as fallback
     prev_max_subgraph_size = prev_config.get("max_subgraph_size")
@@ -294,20 +318,18 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
     if not prev_incorrect_subgraphs:
         return {}, max_subgraph_size

-    print("[Analysis] Refining splits based on previous incorrect models ...")
-
     tasks_map = {}
     for subgraph_path in prev_incorrect_subgraphs:
         # Parse model name and subgraph index
         model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
         model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
         subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
-        print(f"subgraph_path: {subgraph_path}")
-        print(f"model_name: {model_name}, subgraph_idx: {subgraph_idx}")
-        assert model_name in prev_active_models_map
+
+        assert model_name in prev_tasks_map
+        pre_task_for_model = prev_tasks_map[model_name]

         # Reconstruct previous subgraph size to locate the failing segment
-        prev_split_positions = prev_split_positions_map.get(model_name, [])
+        prev_split_positions = pre_task_for_model.get("split_positions", [])
         subgraph_size = reconstruct_subgraph_size(prev_split_positions)
         assert subgraph_idx < len(
             subgraph_size
@@ -316,49 +338,58 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
         if model_name not in tasks_map:
             tasks_map[model_name] = {
                 "subgraph_path": subgraph_path,
-                "original_path": prev_active_models_map[model_name],
+                "original_path": pre_task_for_model["original_path"],
                 "subgraph_size": subgraph_size[subgraph_idx],
                 "split_positions": set(),
             }

     return tasks_map, max_subgraph_size


-def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, pass_work_dir):
-    """Executes the decomposition phase (Phase 1)."""
-    failed_decomposition = []
+def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
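+    """Builds this pass's task map and exits early when there is nothing to decompose."""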
+    if current_pass_id == 0:
+        tasks_map, max_subgraph_size = generate_initial_tasks(args)
+    else:
+        tasks_map, max_subgraph_size = generate_refined_tasks(
+            base_output_dir, current_pass_id
+        )
+
+    print(f"[INFO] initial max_subgraph_size: {max_subgraph_size}")
+    print(f"[INFO] number of incorrect models: {len(tasks_map)}")
+    for model_name, task_info in tasks_map.items():
+        original_path = task_info["original_path"]
+        print(f"- {original_path}")
+
+    if not tasks_map:
+        print("[FINISHED] No models need processing.")
+        sys.exit(0)
+
+    if max_subgraph_size <= 0:
+        print(
+            f"[FINISHED] Cannot decompose with max_subgraph_size {max_subgraph_size}."
+        )
+        sys.exit(0)

+    return tasks_map, max_subgraph_size
+
+
+def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspace):
+    """Executes the decomposition phase."""
+
+    failed_decomposition = []
     need_decompose = True if len(tasks_map) > 0 else False
-    if need_decompose:
-        print("\n--- Phase 1: Decomposition ---", flush=True)

     while need_decompose:
         decomposed_samples_dir = os.path.join(
-            pass_work_dir, "samples" if framework == "torch" else "paddle_samples"
+            workspace, "samples" if framework == "torch" else "paddle_samples"
         )
         if not os.path.exists(decomposed_samples_dir):
             os.makedirs(decomposed_samples_dir, exist_ok=True)
-        print(f"decomposed_samples_dir: {decomposed_samples_dir}")
-
-        for model_name, task_info in tasks_map.items():
-            print(f"[Decomposition] max_subgraph_size: {max_subgraph_size}")
-            original_path = task_info["original_path"]
-            split_positions = calculate_split_positions_for_subgraph(
-                task_info["subgraph_size"], max_subgraph_size
-            )
-            task_info["split_positions"] = split_positions
-
-            rectified_model_path = get_rectfied_model_path(original_path)
-            assert os.path.exists(
-                rectified_model_path
-            ), f"{rectified_model_path} does not exist."
-
-            success = run_naive_decomposer(
-                framework, rectified_model_path, decomposed_samples_dir, split_positions
-            )
-            if not success:
-                failed_decomposition.append(rectified_model_path)
+        print(f"[Decomposition] decomposed_samples_dir: {decomposed_samples_dir}")

+        tasks_map, failed_decomposition = run_decomposer_for_multi_models(
+            framework, tasks_map, decomposed_samples_dir, max_subgraph_size
+        )
         num_decomposed_samples = count_samples(decomposed_samples_dir)
         print(
             f"[Decomposition] number of graphs: {len(tasks_map)} -> {num_decomposed_samples}",
@@ -387,8 +418,8 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, pass_wo
     return tasks_map, failed_decomposition, max_subgraph_size


-def print_final_summary(next_round_models, max_subgraph_size):
-    """Prints the final suggestion/result."""
+def print_summary_and_suggestion(next_round_models, max_subgraph_size):
+    """Prints the summary and suggestion."""
     print("\n" + "=" * 80)
     if next_round_models and max_subgraph_size > 1:
         print(f">>> [SUGGESTION] Issues remain (Count: {len(next_round_models)}).")
@@ -409,59 +440,48 @@ def main(args):
     print(f" GraphNet Auto-Debugger | ROUND: PASS_{current_pass_id}")
     print("=" * 80)

-    # --- Step 1: Prepare Tasks ---
-    if current_pass_id == 0:
-        tasks_map, max_subgraph_size = generate_initial_tasks(args)
-    else:
-        tasks_map, max_subgraph_size = generate_refined_tasks(
-            base_output_dir, current_pass_id
-        )
-
-    print(f"[INFO] initial max_subgraph_size: {max_subgraph_size}")
-    print(f"[INFO] number of incorrect models: {len(tasks_map)}")
-    for model_name, task_info in tasks_map.items():
-        original_path = task_info["original_path"]
-        print(f"- {original_path}")
-
-    if not tasks_map:
-        print("[FINISHED] No models need processing.")
-        sys.exit(0)
-    if max_subgraph_size <= 0:
-        print(
-            f"[FINISHED] Cannot decompose with max_subgraph_size {max_subgraph_size}."
-        )
-        sys.exit(0)
-
-    # --- Step 2: Prepare Workspace ---
+    # --- Step 1: Prepare Tasks and Workspace ---
+    tasks_map, max_subgraph_size = prepare_tasks_and_verify(
+        args, current_pass_id, base_output_dir
+    )
     pass_work_dir = get_decompose_workspace_path(base_output_dir, current_pass_id)
     if not os.path.exists(pass_work_dir):
         os.makedirs(pass_work_dir, exist_ok=True)

-    # --- Step 3: Decomposition ---
+    # --- Step 2: Decomposition ---
     failed_decomposition = []
     if task_controller.task_scheduler["run_decomposer"]:
+        print("\n--- Phase 1: Decomposition ---", flush=True)
         (
             tasks_map,
             failed_decomposition,
             max_subgraph_size,
         ) = execute_decomposition_phase(
             max_subgraph_size, tasks_map, args.framework, pass_work_dir
         )
+    else:
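+        # Decomposition skipped: reload the state this pass saved on a previous run.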
+        config = load_decompose_config(pass_work_dir)
+        max_subgraph_size = config["max_subgraph_size"]
+        failed_decomposition = config["failed_decomposition_models"]
+        tasks_map = config.get("tasks_map", {})

-    # --- Step 4: Testing ---
+    # --- Step 3: Evaluation ---
     pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
     if task_controller.task_scheduler["run_evaluation"]:
         print("\n--- Phase 2: Evaluation ---")
         run_evaluation(args.framework, args.test_config, pass_work_dir, pass_log_path)

-    # --- Step 5: Analysis ---
+    # --- Step 4: Analysis ---
     next_round_models = set()
     if task_controller.task_scheduler["post_analysis"]:
         print("\n--- Phase 3: Analysis ---")
         next_round_models = get_incorrect_models(args.tolerance, pass_log_path)
         print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")

-    # --- Step 6: Save State ---
+    print_summary_and_suggestion(next_round_models, max_subgraph_size)
+    print()
+
+    # --- Step 5: Save States ---
     save_decompose_config(
         pass_work_dir,
         max_subgraph_size,
@@ -470,8 +490,6 @@ def main(args):
         failed_decomposition,
     )

-    print_final_summary(next_round_models, max_subgraph_size)
-

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()