@@ -296,97 +296,168 @@ def calculate_split_positions_for_subgraph(subgraph_size):
     return split_positions


-def main(args):
-    task_controller = TaskController(args)
-    base_output_dir = task_controller.root_output_dir
-    current_pass_id = task_controller.current_pass_id
+def generate_initial_tasks(log_file, tolerance, max_subgraph_size):
+    """Generates tasks for Pass 0 based on the initial log file."""
+    print(f"[Init] Pass 0: Reading from log file: {log_file}")
+    initial_failures = get_incorrect_models(tolerance, log_file)

-    print("=" * 80)
-    print(f" GraphNet Auto-Debugger | ROUND: PASS_{current_pass_id}")
-    print("=" * 80)
+    # Dynamic generation based on step size
+    initial_splits = list(range(0, kMaxGraphSize + 1, max_subgraph_size))

     tasks_map = {}
     active_models_map_for_save = {}

-    # Initialize using the argument passed from bash
-    max_subgraph_size = args.max_subgraph_size
+    for path in initial_failures:
+        name = os.path.basename(path.rstrip("/"))
+        active_models_map_for_save[name] = path
+        tasks_map[name] = {
+            "original_path": path,
+            "split_positions": set(initial_splits),
+        }
+    return tasks_map, active_models_map_for_save, max_subgraph_size

-    if current_pass_id == 0:
-        print(f"[Init] Pass 0: Reading from log file: {args.log_file}")
-        initial_failures = get_incorrect_models(args.tolerance, args.log_file)
-
-        # Dynamic generation based on step size (args.max_subgraph_size)
-        initial_splits = list(range(0, kMaxGraphSize + 1, max_subgraph_size))
-
-        for path in initial_failures:
-            name = os.path.basename(path.rstrip("/"))
-            active_models_map_for_save[name] = path
-            tasks_map[name] = {
-                "original_path": path,
-                "split_positions": set(initial_splits),
-            }
-    else:
-        prev_pass_dir = get_decompose_workspace_path(
-            base_output_dir, current_pass_id - 1
+
+def generate_refined_tasks(base_output_dir, current_pass_id, default_max_size):
+    """Generates tasks for Pass > 0 based on previous pass results."""
+    prev_pass_dir = get_decompose_workspace_path(base_output_dir, current_pass_id - 1)
+    print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")
+
+    prev_config = load_decompose_config(prev_pass_dir)
+    prev_active_models_map = prev_config.get("active_models_map", {})
+    prev_used_splits = prev_config.get("split_positions_map", {})
+    prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
+
+    # Load previous max size as fallback
+    max_subgraph_size = prev_config.get("max_subgraph_size", default_max_size)
+
+    if not prev_incorrect_subgraphs:
+        return {}, {}, max_subgraph_size
+
+    print("[Analysis] Refining splits based on previous incorrect models ...")
+
+    tasks_map = {}
+    active_models_map_for_save = {}
+
+    for subgraph_path in prev_incorrect_subgraphs:
+        # Parse model name and subgraph index
+        model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
+        model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
+        subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
+
+        if model_name not in prev_active_models_map:
+            continue
+
+        active_models_map_for_save[model_name] = prev_active_models_map[model_name]
+
+        # Reconstruct previous subgraph size to locate the failing segment
+        prev_split_positions = sorted(prev_used_splits.get(model_name, []))
+        subgraph_size = reconstruct_subgraph_size(prev_split_positions)
+
+        if subgraph_idx >= len(subgraph_size):
+            print(
+                f"[WARN] Subgraph index {subgraph_idx} out of bounds for {model_name}"
+            )
+            continue
+
+        split_positions = calculate_split_positions_for_subgraph(
+            subgraph_size[subgraph_idx]
         )
-        print(
-            f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})..."
+
+        if model_name not in tasks_map:
+            tasks_map[model_name] = {
+                "subgraph_path": subgraph_path,
+                "original_path": prev_active_models_map[model_name],
+                "subgraph_size": subgraph_size[subgraph_idx],
+                "split_positions": split_positions,
+            }
+
+    return tasks_map, active_models_map_for_save, max_subgraph_size
+
+
+def execute_decomposition_phase(tasks_map, framework, pass_work_dir, should_run):
+    """Executes the decomposition phase (Phase 1)."""
+    failed_decomposition = []
+    final_used_splits_map = {}
+
+    if not should_run or not tasks_map:
+        return failed_decomposition, final_used_splits_map
+
+    print("\n--- Phase 1: Decomposition ---", flush=True)
+
+    decomposed_samples_dir = os.path.join(
+        pass_work_dir, "samples" if framework == "torch" else "paddle_samples"
+    )
+    os.makedirs(decomposed_samples_dir, exist_ok=True)
+    print(f"decomposed_samples_dir: {decomposed_samples_dir}")
+
+    for model_name, task_info in tasks_map.items():
+        original_path = task_info["original_path"]
+        split_positions = sorted(list(task_info["split_positions"]))
+        final_used_splits_map[model_name] = split_positions
+
+        rectified_model_path = get_rectfied_model_path(original_path)
+
+        success = run_decomposer(
+            framework, rectified_model_path, decomposed_samples_dir, split_positions
         )
+        if not success:
+            failed_decomposition.append(rectified_model_path)

-        prev_config = load_decompose_config(prev_pass_dir)
-        prev_active_models_map = prev_config.get("active_models_map", {})
-        prev_used_splits = prev_config.get("split_positions_map", {})
-        prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
+    num_samples = count_samples(decomposed_samples_dir)
+    print(f"- number of graphs: {len(tasks_map)} -> {num_samples}", flush=True)

-        # Load previous max size as fallback for calculation
-        prev_max_size = prev_config.get("max_subgraph_size", args.max_subgraph_size)
-        max_subgraph_size = prev_max_size
+    if failed_decomposition:
+        print(f"[WARN] {len(failed_decomposition)} models failed to decompose.")

-        if not prev_incorrect_subgraphs:
-            print("[FINISHED] Debugging completed.")
-            sys.exit(0)
+    return failed_decomposition, final_used_splits_map

-        print("[Analysis] Refining splits based on previous incorrect models ...")

-        for subgraph_path in prev_incorrect_subgraphs:
-            print(f"- subgraph_path: {subgraph_path}")
-            model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
-            model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
-            subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
-            print(f"- model_name: {model_name}, subgraph_idx: {subgraph_idx}")
+def print_final_summary(next_round_models, real_subgraph_size):
+    """Prints the final suggestion/result."""
+    print("\n" + "=" * 80)
+    if next_round_models and real_subgraph_size > 1:
+        print(f">>> [SUGGESTION] Issues remain (Count: {len(next_round_models)}).")
+        print(">>> Please start next round decomposition test (Run this script again).")
+    elif next_round_models and real_subgraph_size <= 1:
+        print(">>> [FAILURE] Minimal granularity reached, but errors persist.")
+    else:
+        print(">>> [SUCCESS] Debugging converged.")
+    print("=" * 80)

-            assert model_name in prev_active_models_map
-            active_models_map_for_save[model_name] = prev_active_models_map[model_name]

-            # Reconstruct previous subgraph size to locate the failing segment
-            prev_split_positions = sorted(prev_used_splits.get(model_name, []))
-            subgraph_size = reconstruct_subgraph_size(prev_split_positions)
-            assert subgraph_idx < len(
-                subgraph_size
-            ), f"subgraph_idx {subgraph_idx} is out of bounds for {model_name} (previous split_positions: {prev_split_positions})"
+def main(args):
+    task_controller = TaskController(args)
+    base_output_dir = task_controller.root_output_dir
+    current_pass_id = task_controller.current_pass_id
+
+    print("=" * 80)
+    print(f" GraphNet Auto-Debugger | ROUND: PASS_{current_pass_id}")
+    print("=" * 80)

-            split_positions = calculate_split_positions_for_subgraph(
-                subgraph_size[subgraph_idx]
-            )
-            if model_name not in tasks_map:
-                tasks_map[model_name] = {
-                    "subgraph_path": subgraph_path,
-                    "original_path": prev_active_models_map[model_name],
-                    "subgraph_size": subgraph_size[subgraph_idx],
-                    "split_positions": split_positions,
-                }
-            else:
-                continue
-
-    # Recalculate based on current map to ensure log accuracy
+    # --- Step 1: Prepare Tasks ---
+    if current_pass_id == 0:
+        (
+            tasks_map,
+            active_models_map_for_save,
+            max_subgraph_size,
+        ) = generate_initial_tasks(
+            args.log_file, args.tolerance, args.max_subgraph_size
+        )
+    else:
+        (
+            tasks_map,
+            active_models_map_for_save,
+            max_subgraph_size,
+        ) = generate_refined_tasks(
+            base_output_dir, current_pass_id, args.max_subgraph_size
+        )
+
+    # Recalculate size for logging
     real_subgraph_size = calculate_current_subgraph_size(tasks_map, max_subgraph_size)
     print(f"[INFO] Current Subgraph Size: {real_subgraph_size}")
     print(f"[INFO] Models to Process: {len(tasks_map)}")
-    for model_name, task_info in tasks_map.items():
-        original_path = task_info["original_path"]
-        print(f"- {original_path}")

-    if not tasks_map:
+    if not tasks_map and current_pass_id > 0:
         print("[FINISHED] No models need processing.")
         sys.exit(0)

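The two task generators above turn on two small mechanics: Pass 0 seeds evenly spaced split positions, and later passes parse a failing subgraph directory name back into its model name and index. A minimal sketch of both, using the same expressions as the diff (the constant value and the path are hypothetical, purely for illustration):

    import os

    # Pass 0: one split every max_subgraph_size ops, from 0 to kMaxGraphSize.
    kMaxGraphSize = 100  # hypothetical value; the real constant lives in the module
    max_subgraph_size = 30
    initial_splits = list(range(0, kMaxGraphSize + 1, max_subgraph_size))
    print(initial_splits)  # [0, 30, 60, 90]

    # Pass > 0: a failing subgraph dir named "<model_name>_<idx>" is split back
    # apart on underscores, so multi-underscore model names survive intact.
    subgraph_path = "/work/pass_0/samples/resnet_block_3/"  # hypothetical path
    leaf = subgraph_path.rstrip("/").split(os.sep)[-1]
    model_name = "_".join(leaf.split("_")[:-1])  # "resnet_block"
    subgraph_idx = int(leaf.split("_")[-1])      # 3
    print(model_name, subgraph_idx)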
@@ -396,57 +467,17 @@ def main(args):
     os.makedirs(pass_work_dir, exist_ok=True)

     # --- Step 3: Decomposition ---
-    need_decompose = (
-        True
-        if task_controller.task_scheduler["run_decomposer"] and len(tasks_map) > 0
-        else False
+    failed_decomposition, final_used_splits_map = execute_decomposition_phase(
+        tasks_map,
+        args.framework,
+        pass_work_dir,
+        task_controller.task_scheduler["run_decomposer"],
     )
-    if need_decompose:
-        print("\n--- Phase 1: Decomposition ---", flush=True)
-
-    failed_decomposition = []
-    final_used_splits_map = {}
-    if need_decompose:
-        decomposed_samples_dir = os.path.join(
-            pass_work_dir, "samples" if args.framework == "torch" else "paddle_samples"
-        )
-        os.makedirs(decomposed_samples_dir, exist_ok=True)
-        print(f"decomposed_samples_dir: {decomposed_samples_dir}")
-
-        for model_name, task_info in tasks_map.items():
-            original_path = task_info["original_path"]
-            split_positions = sorted(list(task_info["split_positions"]))
-
-            final_used_splits_map[model_name] = split_positions
-
-            rectied_model_path = get_rectfied_model_path(original_path)
-            print(f"original_path: {original_path}")
-            print(f"rectied_model_path: {rectied_model_path}")
-            assert os.path.exists(
-                rectied_model_path
-            ), f"{rectied_model_path} does not exist."
-
-            success = run_decomposer(
-                args.framework,
-                rectied_model_path,
-                decomposed_samples_dir,
-                split_positions,
-            )
-            if not success:
-                failed_decomposition.append(rectied_model_path)
-
-        num_decomposed_samples = count_samples(decomposed_samples_dir)
-        print(
-            f"- number of graphs: {len(tasks_map)} -> {num_decomposed_samples}",
-            flush=True,
-        )
-        if failed_decomposition:
-            print(f"[WARN] {len(failed_decomposition)} models failed to decompose.")

     # --- Step 4: Testing ---
+    pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
     if task_controller.task_scheduler["run_evaluation"]:
         print("\n--- Phase 2: Batch Testing ---")
-        pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
         run_evaluation(args.framework, args.test_config, pass_work_dir, pass_log_path)

     # --- Step 5: Analysis ---
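One note on this hunk: pass_log_path is now computed outside the run_evaluation gate, so the later analysis step can still locate batch_test_result.log when evaluation is skipped. Separately, reconstruct_subgraph_size is defined outside this diff; given that its result is indexed by subgraph_idx and compared against len(), a plausible reading is "segment lengths between consecutive split positions". A hypothetical sketch of that contract, not the PR's actual implementation:

    def reconstruct_subgraph_size(split_positions):
        # Hypothetical: sizes of the segments delimited by sorted split positions.
        return [b - a for a, b in zip(split_positions, split_positions[1:])]

    print(reconstruct_subgraph_size([0, 30, 60, 90]))  # [30, 30, 30]

calculate_split_positions_for_subgraph (also out of view) then presumably subdivides the one failing segment more finely for the next pass.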
@@ -466,15 +497,7 @@ def main(args):
         failed_decomposition,
     )

-    print("\n" + "=" * 80)
-    if next_round_models and real_subgraph_size > 1:
-        print(f">>> [SUGGESTION] Issues remain (Count: {len(next_round_models)}).")
-        print(">>> Please start next round decomposition test (Run this script again).")
-    elif next_round_models and real_subgraph_size <= 1:
-        print(">>> [FAILURE] Minimal granularity reached, but errors persist.")
-    else:
-        print(">>> [SUCCESS] Debugging converged.")
-    print("=" * 80)
+    print_final_summary(next_round_models, real_subgraph_size)


 if __name__ == "__main__":
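The body of the `if __name__ == "__main__":` block falls outside this diff. For orientation only, a hypothetical entry point, assuming argparse flags that mirror the args attributes main() uses (flag names, types, and choices are all illustrative, not taken from this PR):

    import argparse

    parser = argparse.ArgumentParser(description="GraphNet Auto-Debugger")
    parser.add_argument("--log_file")              # read in generate_initial_tasks
    parser.add_argument("--tolerance", type=float)
    parser.add_argument("--max_subgraph_size", type=int)
    parser.add_argument("--framework", choices=["torch", "paddle"])
    parser.add_argument("--test_config")
    main(parser.parse_args())

The real args object must also satisfy TaskController, whose fields are not visible in this diff.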