@@ -139,6 +139,19 @@ def load(self, work_dir):
139139 def get_config_path (self , work_dir ) -> str :
140140 return os .path .join (work_dir , "decompose_config.json" )
141141
142+ def update_running_states (self , pass_id , ** kwargs ):
143+ pass_key = get_pass_name (pass_id )
144+ if self .running_states .get (pass_key , None ) is None :
145+ self .running_states [pass_key ] = {}
146+
147+ for key , value in kwargs .items ():
148+ assert key in [
149+ "num_incorrect_models" ,
150+ "incorrect_models" ,
151+ "failed_decomposition_models" ,
152+ ]
153+ self .running_states [pass_key ][key ] = value
154+
142155
143156def get_rectfied_model_path (model_path ):
144157 graphnet_root = path_utils .get_graphnet_root ()
@@ -268,11 +281,10 @@ def run_evaluation(
268281
269282def reconstruct_subgraph_size (split_positions : List [int ]) -> List [list ]:
270283 """Reconstructs subgraph size based on sorted split positions."""
271- deduplicated_splits = list ( dict . fromkeys (split_positions ))
284+ deduplicated_splits = sorted ( set (split_positions ))
272285
273286 subgraph_size = [
274- [deduplicated_splits [i ], deduplicated_splits [i + 1 ]]
275- for i in range (len (deduplicated_splits ) - 1 )
287+ deduplicated_splits [i : i + 2 ] for i in range (len (deduplicated_splits ) - 1 )
276288 ]
277289 return subgraph_size
278290
@@ -328,7 +340,7 @@ def extract_model_name_and_subgraph_idx(subgraph_path):
328340 return model_name , subgraph_idx
329341
330342
331- def generate_refined_tasks (base_output_dir , current_pass_id ):
343+ def generate_successor_tasks (base_output_dir , current_pass_id ):
332344 """Generates tasks for Pass > 0 based on previous pass results."""
333345 prev_pass_dir = get_decompose_workspace_path (base_output_dir , current_pass_id - 1 )
334346 print (f"[Init] Resuming from Pass_{ current_pass_id - 1 } (Dir: { prev_pass_dir } )..." )
@@ -377,7 +389,7 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
377389 if current_pass_id == 0 :
378390 tasks_map , max_subgraph_size , running_states = generate_initial_tasks (args )
379391 else :
380- tasks_map , max_subgraph_size , running_states = generate_refined_tasks (
392+ tasks_map , max_subgraph_size , running_states = generate_successor_tasks (
381393 base_output_dir , current_pass_id
382394 )
383395
@@ -435,15 +447,13 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
435447 os .makedirs (decomposed_samples_dir , exist_ok = True )
436448 max_subgraph_size = max (1 , max_subgraph_size // 2 )
437449 for model_name , task_info in tasks_map .items ():
438- splits = task_info ["split_positions" ]
439- if not splits or len (splits ) < 2 :
450+ split_positions = task_info ["split_positions" ]
451+ if not split_positions or len (split_positions ) < 2 :
440452 continue
441- start_pos = splits [0 ]
442- first_segment_end = splits [1 ]
443- new_splits = calculate_split_positions_for_subgraph (
444- [start_pos , first_segment_end ], max_subgraph_size
453+ new_split_positions = calculate_split_positions_for_subgraph (
454+ split_positions [0 :2 ], max_subgraph_size
445455 )
446- task_info ["split_positions" ] = new_splits
456+ task_info ["split_positions" ] = new_split_positions
447457 else :
448458 need_decompose = False
449459 print ()
@@ -454,6 +464,15 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
454464 return tasks_map , failed_decomposition , max_subgraph_size
455465
456466
467+ def count_unique_original_models (incorrect_models ):
468+ original_model_paths = set (
469+ model_name
470+ for subgraph_path in incorrect_models
471+ for model_name , _ in [extract_model_name_and_subgraph_idx (subgraph_path )]
472+ )
473+ return len (original_model_paths )
474+
475+
457476def print_summary_and_suggestion (next_round_models , max_subgraph_size ):
458477 """Print suggestion/result."""
459478 print ("\n " + "=" * 80 )
@@ -480,9 +499,14 @@ def main(args):
480499 tasks_map , max_subgraph_size , running_states = prepare_tasks_and_verify (
481500 args , current_pass_id , base_output_dir
482501 )
483- pass_work_dir = get_decompose_workspace_path (base_output_dir , current_pass_id )
484- if not os .path .exists (pass_work_dir ):
485- os .makedirs (pass_work_dir , exist_ok = True )
502+ decompose_config = DecomposeConfig (
503+ max_subgraph_size = max_subgraph_size ,
504+ tasks_map = tasks_map ,
505+ running_states = running_states ,
506+ )
507+ work_dir = get_decompose_workspace_path (base_output_dir , current_pass_id )
508+ if not os .path .exists (work_dir ):
509+ os .makedirs (work_dir , exist_ok = True )
486510
487511 # --- Step 2: Decomposition ---
488512 if task_controller .task_scheduler ["run_decomposer" ]:
@@ -492,63 +516,45 @@ def main(args):
492516 failed_decomposition ,
493517 max_subgraph_size ,
494518 ) = execute_decomposition_phase (
495- max_subgraph_size , tasks_map , args .framework , pass_work_dir
519+ max_subgraph_size , tasks_map , args .framework , work_dir
520+ )
521+ decompose_config .update_running_states (
522+ current_pass_id , failed_decomposition_models = list (failed_decomposition )
496523 )
497- running_states .get (f"pass_{ current_pass_id } " , {})[
498- "failed_decomposition_models"
499- ] = list (failed_decomposition )
500524 else :
501525 print ("\n --- Phase 1: Decomposition (skipped) ---" , flush = True )
502- config = DecomposeConfig .load (pass_work_dir )
503- max_subgraph_size = config .max_subgraph_size
504- tasks_map = config .tasks_map
505- running_states = config .running_states
526+ decompose_config = DecomposeConfig .load (work_dir )
506527
507528 # --- Step 3: Evaluation ---
508- pass_log_path = os .path .join (pass_work_dir , "batch_test_result.log " )
529+ log_path = os .path .join (work_dir , f"log_ { task_controller . test_module_name } .txt " )
509530 if task_controller .task_scheduler ["run_evaluation" ]:
510531 print (f"\n --- Phase 2: Evaluation ({ task_controller .test_module_name } ) ---" )
511- run_evaluation (args .framework , args .test_config , pass_work_dir , pass_log_path )
532+ run_evaluation (args .framework , args .test_config , work_dir , log_path )
512533
513534 # --- Step 4: Analysis ---
514- next_round_models = set ()
535+ next_pass_incorrect_models = set ()
515536 if task_controller .task_scheduler ["post_analysis" ]:
516537 tolerance = (
517538 args .tolerance [0 ] if isinstance (args .tolerance , list ) else args .tolerance
518539 )
519540 print (f"\n --- Phase 3: Analysis (torlance={ tolerance } ) ---" )
520- next_round_models = sorted (get_incorrect_models (tolerance , pass_log_path ))
521- original_model_paths = set (
522- [
523- model_name
524- for subgraph_path in next_round_models
525- for model_name , _ in [
526- extract_model_name_and_subgraph_idx (subgraph_path )
527- ]
528- ]
541+ next_pass_incorrect_models = sorted (get_incorrect_models (tolerance , log_path ))
542+ num_original_models = count_unique_original_models (next_pass_incorrect_models )
543+ decompose_config .update_running_states (
544+ current_pass_id + 1 ,
545+ num_incorrect_models = num_original_models ,
546+ incorrect_models = list (next_pass_incorrect_models ),
529547 )
530-
531- running_states [f"pass_{ current_pass_id + 1 } " ] = {
532- "num_incorrect_models" : len (original_model_paths ),
533- "incorrect_models" : list (next_round_models ),
534- }
535-
536548 print (
537- f"[Analysis] Found { len (next_round_models )} incorrect subgraphs ({ len ( original_model_paths ) } original models)."
549+ f"[Analysis] Found { len (next_pass_incorrect_models )} incorrect subgraphs ({ num_original_models } original models)."
538550 )
539- for idx , model_path in enumerate (next_round_models ):
551+ for idx , model_path in enumerate (next_pass_incorrect_models ):
540552 print (f"- [{ idx } ] { model_path } " )
541-
542- print_summary_and_suggestion (next_round_models , max_subgraph_size )
553+ print_summary_and_suggestion (next_pass_incorrect_models , max_subgraph_size )
543554
544555 # --- Step 5: Save States ---
545- config = DecomposeConfig (
546- max_subgraph_size = max_subgraph_size ,
547- incorrect_models = list (next_round_models ),
548- tasks_map = tasks_map ,
549- running_states = running_states ,
550- )
551- config .save (pass_work_dir )
556+ decompose_config .incorrect_models = list (next_pass_incorrect_models )
557+ decompose_config .save (work_dir )
552558
553559
554560if __name__ == "__main__" :
0 commit comments