@@ -85,29 +85,40 @@ def count_samples(samples_dir):
     return num_samples


-def load_prev_config(pass_id: int, output_dir: str) -> Dict[str, Any]:
+def get_decompose_config_path(output_dir: str) -> str:
+    """Returns the full path to the decompose configuration file."""
+    return os.path.join(output_dir, "decompose_config.json")
+
+
+def load_decompose_config(pass_id: int, output_dir: str) -> Dict[str, Any]:
     """Loads the configuration file from the previous pass."""
     prev_dir = os.path.join(output_dir, f"pass_{pass_id - 1}")
-    config_path = os.path.join(prev_dir, "decompose_config.json")
+    config_path = get_decompose_config_path(prev_dir)
+
     if not os.path.exists(config_path):
         raise FileNotFoundError(f"Missing configuration file: {config_path}")
     with open(config_path, "r") as f:
         return json.load(f)


-def save_current_config(
+def save_decompose_config(
     work_dir: str,
-    current_max_size: int,
-    incorrect_models: Union[List[str], Set[str]],
-    failed_models: List[str],
+    max_subgraph_size: int,
+    incorrect_paths: Union[List[str], Set[str]],
+    active_models_map: Dict[str, str],
+    split_positions_map: Dict[str, List[int]],
+    failed_decomposition_models: Union[List[str], Set[str]],
 ):
-    """Saves the current state."""
+    """Saves the current state to a JSON file."""
     config = {
-        "current_max_subgraph_size": current_max_size,
-        "incorrect_models": list(incorrect_models),
-        "failed_extraction_models": list(failed_models),
+        "max_subgraph_size": max_subgraph_size,
+        "incorrect_models": list(incorrect_paths),
+        "active_models_map": active_models_map,
+        "split_positions_map": split_positions_map,
+        "failed_decomposition_models": list(failed_decomposition_models),
     }
-    config_path = os.path.join(work_dir, "decompose_config.json")
+    config_path = get_decompose_config_path(work_dir)
+
     with open(config_path, "w") as f:
         json.dump(config, f, indent=4)
     print(f"[INFO] State saved to: {config_path}")
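
A minimal usage sketch of the save/load pair above (paths and values are hypothetical; assumes both helpers and their json/os imports are in scope): pass N writes pass_N/decompose_config.json, and pass N+1 locates it from its own pass_id.

    import os
    import tempfile

    with tempfile.TemporaryDirectory() as root:
        pass0_dir = os.path.join(root, "pass_0")
        os.makedirs(pass0_dir)
        save_decompose_config(
            work_dir=pass0_dir,
            max_subgraph_size=512,
            incorrect_paths={"samples/model_a/subgraph_3"},  # illustrative
            active_models_map={"model_a": "samples/model_a"},
            split_positions_map={"model_a": [0, 512, 1024]},
            failed_decomposition_models=[],
        )
        # Pass 1 resolves pass_0 from its own pass_id:
        cfg = load_decompose_config(pass_id=1, output_dir=root)
        assert cfg["split_positions_map"]["model_a"] == [0, 512, 1024]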
@@ -123,15 +134,10 @@ def run_decomposer(
     framework: str,
     model_path: str,
     output_dir: str,
-    max_subgraph_size: int,
+    split_positions: List[int],
 ) -> bool:
     """Decomposes a single model."""

-    upper_bound = 4096
-    split_positions = list(
-        range(max_subgraph_size, upper_bound + max_subgraph_size, max_subgraph_size)
-    )
-
     graphnet_root = path_utils.get_graphnet_root()
     model_name = get_model_name_with_subgraph_tag(model_path)
     decorator_config = {
@@ -142,7 +148,7 @@ def run_decomposer(
         "custom_extractor_config": {
             "output_dir": output_dir,
             "split_positions": split_positions,
-            "group_head_and_tail": True,
+            "group_head_and_tail": False,
             "chain_style": False,
         },
     },
@@ -151,9 +157,9 @@ def run_decomposer(
         json.dumps(decorator_config).encode()
     ).decode()

-    print(
-        f"- [Decomposing] {model_name} (max_subgraph_size={max_subgraph_size}, split_positions={split_positions})"
-    )
+    print(f"[Decomposing] {model_path}")
+    print(f"[Strategy] split_positions: {split_positions}")
+
     cmd = [
         sys.executable,
         "-m",
@@ -196,8 +202,8 @@ def run_evaluation(
         for item in (f"--{key}", str(value))
     ]

-    print(f" [Batch Testing] Logging to: {log_path}")
-    print(f" [Command] {' '.join(cmd)}")
+    print(f"[Batch Testing] Logging to: {log_path}")
+    print(f"[Command] {' '.join(cmd)}")

     os.makedirs(os.path.dirname(log_path), exist_ok=True)
     with open(log_path, "w") as f:
@@ -209,6 +215,47 @@ def run_evaluation(
         sys.exit(proc.returncode)


+def reconstruct_subgraph_size(split_positions: List[int]) -> List[tuple]:
+    """Reconstructs subgraph (start, end) intervals from split positions."""
+    full_splits = sorted(set(split_positions))
+
+    subgraph_size = []
+    # Need at least 2 points to form an interval.
+    if len(full_splits) < 2:
+        return []
+
+    for i in range(len(full_splits) - 1):
+        subgraph_size.append((full_splits[i], full_splits[i + 1]))
+
+    return subgraph_size
+
+
+def calculate_current_subgraph_size(
+    tasks_map: Dict[str, Dict], fallback_size: int
+) -> int:
+    """Calculates the current subgraph size from the generated tasks."""
+    current_subgraph_size = float("inf")
+    found_splits = False
+
+    for _, info in tasks_map.items():
+        splits = sorted(info["split_positions"])
+
+        if len(splits) < 2:
+            continue
+
+        found_splits = True
+        for i in range(len(splits) - 1):
+            diff = splits[i + 1] - splits[i]
+            if diff > 0:
+                current_subgraph_size = min(current_subgraph_size, diff)
+
+    return (
+        int(current_subgraph_size)
+        if found_splits and current_subgraph_size != float("inf")
+        else fallback_size
+    )
+
+
 def main(args):
     task_controller = TaskController(args)
     base_output_dir = task_controller.root_output_dir
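
A worked example of the two helpers above (values are illustrative; assumes both functions are in scope):

    positions = [1024, 0, 512, 512]  # unsorted, with a duplicate
    assert reconstruct_subgraph_size(positions) == [(0, 512), (512, 1024)]

    tasks = {
        "model_a": {"split_positions": {0, 512, 1024}},  # smallest gap: 512
        "model_b": {"split_positions": {0, 256}},        # smallest gap: 256
    }
    # The effective size is the smallest positive gap across all tasks;
    # fallback_size is used only when no task has two or more positions.
    assert calculate_current_subgraph_size(tasks, fallback_size=2048) == 256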
@@ -218,28 +265,119 @@ def main(args):
     print(f" GraphNet Auto-Debugger | ROUND: PASS_{current_pass_id}")
     print("=" * 80)

-    # --- Step 1: Initialize State ---
-    target_models = []
-    current_max_size = args.max_subgraph_size
+    tasks_map = {}
+    active_models_map_for_save = {}
+    kMaxGraphSize = 4096
+
+    # Initialize using the step-size argument passed in from the bash driver.
+    max_subgraph_size = args.max_subgraph_size
+
     if current_pass_id == 0:
         print(f"[Init] Pass 0: Reading from log file: {args.log_file}")
-        current_max_size = args.max_subgraph_size
-        target_models = get_incorrect_models(args.tolerance, args.log_file)
-    else:
-        print(f"[Init] Resuming from Pass {current_pass_id - 1}...")
-        prev_config = load_prev_config(current_pass_id, base_output_dir)
-        target_models = prev_config.get("incorrect_models", [])
+        initial_failures = get_incorrect_models(args.tolerance, args.log_file)

-        prev_size = prev_config.get("current_max_subgraph_size", 2048)
-        current_max_size = max(1, prev_size // 2)
+        # Dynamic generation based on the step size (args.max_subgraph_size).
+        initial_splits = list(range(0, kMaxGraphSize + 1, max_subgraph_size))

-    print(f"[INFO] current max_subgraph_size: {current_max_size}")
-    print(f"[INFO] number of incorrect models: {len(target_models)}")
-    for model_path in target_models:
-        print(f"- {model_path}")
+        for path in initial_failures:
+            name = os.path.basename(path.rstrip("/"))
+            active_models_map_for_save[name] = path
+            tasks_map[name] = {
+                "original_path": path,
+                "split_positions": set(initial_splits),
+            }
+    else:
+        prev_pass_dir = os.path.join(base_output_dir, f"pass_{current_pass_id - 1}")
+        print(
+            f"[Init] Resuming from Pass {current_pass_id - 1} (Dir: {prev_pass_dir})..."
+        )

-    if not target_models:
-        print(f"[FINISHED] Debugging completed.")
+        prev_config = load_decompose_config(current_pass_id, base_output_dir)
+        prev_map = prev_config.get("active_models_map", {})
+
+        prev_used_splits = prev_config.get("split_positions_map", {})
+        prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
+
+        # Load the previous max size as a fallback for the calculation below.
+        prev_max_size = prev_config.get("max_subgraph_size", args.max_subgraph_size)
+        max_subgraph_size = prev_max_size
+
+        if not prev_incorrect_subgraphs:
+            print("[FINISHED] Debugging completed.")
+            sys.exit(0)
+
+        print("[Analysis] Refining splits based on failures...")
+
+        for sub_path in prev_incorrect_subgraphs:
+            parts = sub_path.rstrip("/").split("/")
+            if len(parts) < 2:
+                continue
+
+            subgraph_dirname = parts[-1]
+            model_name = parts[-2]
+
+            if model_name in prev_map:
+                active_models_map_for_save[model_name] = prev_map[model_name]
+                if model_name not in tasks_map:
+                    tasks_map[model_name] = {
+                        "original_path": prev_map[model_name],
+                        "split_positions": set(),
+                    }
+            else:
+                continue
+
+            try:
+                sub_idx = int(subgraph_dirname.split("_")[-1])
+            except ValueError:
+                continue
+
+            # 1. Reconstruct the previous intervals to locate the failing segment.
+            old_split_position = sorted(prev_used_splits.get(model_name, []))
+            subgraph_size = reconstruct_subgraph_size(old_split_position)
+
+            if sub_idx >= len(subgraph_size):
+                print(
+                    f"[WARN] Index {sub_idx} out of bounds for {model_name} (old split positions: {old_split_position})"
+                )
+                continue
+
+            # 2. Get the specific failing interval [start, end].
+            fail_start, fail_end = subgraph_size[sub_idx]
+
+            # Clamp open-ended intervals (float('inf')) to the maximum graph size.
+            if fail_end == float("inf"):
+                fail_end = kMaxGraphSize
+
+            # Dynamic step calculation.
+            subgraph_size_len = fail_end - fail_start
+            new_step = subgraph_size_len // 2
+
+            if new_step < 1:
+                new_step = subgraph_size_len
+
+            # 3. Calculate the midpoint.
+            mid_point = fail_start + new_step
+
+            # 4. Add split positions.
+            if fail_start < mid_point < fail_end:
+                tasks_map[model_name]["split_positions"].update(
+                    [int(fail_start), int(mid_point), int(fail_end)]
+                )
+            else:
+                tasks_map[model_name]["split_positions"].update(
+                    [int(fail_start), int(fail_end)]
+                )
+
+    # Recalculate from the current map to ensure log accuracy.
+    real_subgraph_size = calculate_current_subgraph_size(tasks_map, max_subgraph_size)
+    print(f"[INFO] Current Subgraph Size: {real_subgraph_size}")
+    print(f"[INFO] Models to Process: {len(tasks_map)}")
+    for model_name, task_info in tasks_map.items():
+        original_path = task_info["original_path"]
+        print(f"- {original_path}")
+
+    if not tasks_map:
+        print("[FINISHED] No models need processing.")
         sys.exit(0)

     # --- Step 2: Prepare Workspace ---
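
To make the interval refinement above concrete, here is the arithmetic for one hypothetical failure: suppose subgraph_2 of a model from the previous pass (split at [0, 1024, 2048, 3072]) is still incorrect.

    # Illustrative numbers only.
    old_splits = [0, 1024, 2048, 3072]
    intervals = [(0, 1024), (1024, 2048), (2048, 3072)]  # reconstruct_subgraph_size
    fail_start, fail_end = intervals[2]                  # subgraph_2 failed
    new_step = (fail_end - fail_start) // 2              # 512
    mid_point = fail_start + new_step                    # 2560
    # The next pass merges {2048, 2560, 3072} into this model's split set,
    # halving the failing region from 1024 nodes to 512.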
@@ -250,42 +388,48 @@ def main(args):
     # --- Step 3: Decomposition ---
+    failed_decomposition = []
+    final_used_splits_map = {}
     need_decompose = (
         True
-        if task_controller.task_scheduler["run_decomposer"] and len(target_models) > 0
+        if task_controller.task_scheduler["run_decomposer"] and len(tasks_map) > 0
         else False
     )
     if need_decompose:
         print("\n--- Phase 1: Decomposition ---", flush=True)
-        failed_extraction = []
+
         while need_decompose:
             decomposed_samples_dir = os.path.join(
                 pass_work_dir, "samples" if args.framework == "torch" else "paddle_samples"
             )
             os.makedirs(decomposed_samples_dir, exist_ok=True)

-            for idx, model_path in enumerate(target_models):
-                rectied_model_path = get_rectfied_model_path(model_path)
+            for model_name, task_info in tasks_map.items():
+                original_path = task_info["original_path"]
+                split_positions = sorted(task_info["split_positions"])
+                final_used_splits_map[model_name] = split_positions
+
+                rectied_model_path = get_rectfied_model_path(original_path)
                 assert os.path.exists(
                     rectied_model_path
                 ), f"{rectied_model_path} does not exist."

-                os.makedirs(decomposed_samples_dir, exist_ok=True)
                 success = run_decomposer(
                     args.framework,
                     rectied_model_path,
                     decomposed_samples_dir,
-                    current_max_size,
+                    split_positions,
                 )
                 if not success:
-                    failed_extraction.append(rectied_model_path)
+                    failed_decomposition.append(rectied_model_path)
+
             num_decomposed_samples = count_samples(decomposed_samples_dir)
             print(
-                f"- number of graphs: {len(target_models)} -> {num_decomposed_samples}",
+                f"- number of graphs: {len(tasks_map)} -> {num_decomposed_samples}",
                 flush=True,
             )
-            if failed_extraction:
-                print(f"[WARN] {len(failed_extraction)} models failed to decompose.")
+            if failed_decomposition:
+                print(f"[WARN] {len(failed_decomposition)} models failed to decompose.")

-            if num_decomposed_samples == len(target_models):
+            if not failed_decomposition and num_decomposed_samples == len(tasks_map):
                 need_decompose = True
                 shutil.rmtree(decomposed_samples_dir)
                 os.makedirs(decomposed_samples_dir, exist_ok=True)
@@ -300,26 +444,27 @@ def main(args):
     run_evaluation(args.framework, args.test_config, pass_work_dir, pass_log_path)

     # --- Step 5: Analysis ---
     next_round_models = set()
     if task_controller.task_scheduler["post_analysis"]:
         print("\n--- Phase 3: Analysis ---")
-        next_round_models = set()
-        try:
-            next_round_models = set(get_incorrect_models(args.tolerance, pass_log_path))
-            print(f" [Result] Found {len(next_round_models)} incorrect subgraphs.")
-        except Exception as e:
-            print(f" [ERROR] Log analysis failed: {e}")
+        next_round_models = set(get_incorrect_models(args.tolerance, pass_log_path))
+        print(f"[Result] Found {len(next_round_models)} incorrect subgraphs.")

     # --- Step 6: Save State ---
-    save_current_config(
-        pass_work_dir, current_max_size, next_round_models, failed_extraction
+    save_decompose_config(
+        pass_work_dir,
+        real_subgraph_size,
+        next_round_models,
+        active_models_map_for_save,
+        final_used_splits_map,
+        failed_decomposition,
     )

     print("\n" + "=" * 80)
-    if next_round_models and current_max_size > 1:
+    if next_round_models and real_subgraph_size > 1:
         print(f">>> [SUGGESTION] Issues remain (Count: {len(next_round_models)}).")
         print(">>> Please start next round decomposition test (Run this script again).")
-    elif next_round_models and current_max_size <= 1:
+    elif next_round_models and real_subgraph_size <= 1:
         print(f">>> [FAILURE] Minimal granularity reached, but errors persist.")
     else:
         print(f">>> [SUCCESS] Debugging converged.")
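
A back-of-envelope consequence of the midpoint refinement (not part of the script): each pass halves the failing interval, so starting from kMaxGraphSize = 4096 the loop reaches real_subgraph_size <= 1 for a given segment in at most about ceil(log2(4096)) = 12 passes.

    import math

    # Each pass halves the failing interval, so isolating a single
    # problematic split inside a 4096-node graph takes at most:
    print(math.ceil(math.log2(4096)))  # 12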