@@ -18,6 +18,29 @@ def convert_b64_string_to_json(b64str):
1818 return json .loads (base64 .b64decode (b64str ).decode ("utf-8" ))
1919
2020
def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set:
    """Return models that fail at the first tolerance but pass at the second.

    With a single tolerance value this is simply the failure set at that
    tolerance. With two values, the result is the set difference
    Fail(t_start) - Fail(t_end), i.e. models presumably sensitive to the
    tolerance range (semantics depend on ``get_incorrect_models`` — confirm).
    Returns an empty set when the log file does not exist.
    """
    if not os.path.exists(log_path):
        return set()

    lower = tolerance_args[0]
    failing_at_lower = set(get_incorrect_models(lower, log_path))

    # Single tolerance: no range filtering to do.
    if len(tolerance_args) == 1:
        return failing_at_lower

    upper = tolerance_args[1]
    failing_at_upper = set(get_incorrect_models(upper, log_path))

    print(f"[Filter] Tolerance Range: {lower} -> {upper}")
    print(
        f"[Filter] Fail({lower}): {len(failing_at_lower)}, Fail({upper}): {len(failing_at_upper)}"
    )

    return failing_at_lower - failing_at_upper
42+
43+
2144class TaskController :
2245 def __init__ (self , args ):
2346 self .root_output_dir = os .path .abspath (args .output_dir )
@@ -203,11 +226,8 @@ def run_decomposer_for_multi_models(
203226 )
204227 for model_name , task_info in tasks_map .items ():
205228 original_path = task_info ["original_path" ]
206- split_positions = calculate_split_positions_for_subgraph (
207- task_info ["subgraph_size" ], max_subgraph_size
208- )
209- task_info ["split_positions" ] = split_positions
210229
230+ split_positions = sorted (list (task_info ["split_positions" ]))
211231 rectified_model_path = get_rectfied_model_path (original_path )
212232 assert os .path .exists (
213233 rectified_model_path
@@ -275,26 +295,36 @@ def calculate_split_positions_for_subgraph(subgraph_size, max_subgraph_size):
275295 end_pos = kMaxGraphSize if end_pos == float ("inf" ) else end_pos
276296
277297 split_positions = list (range (start_pos , end_pos + 1 , max_subgraph_size ))
278- deduplicated_splits = list (dict .fromkeys (split_positions ))
279- return deduplicated_splits
298+ if split_positions [- 1 ] != end_pos :
299+ split_positions .append (end_pos )
300+ return sorted (list (set (split_positions )))
280301
281302
def generate_initial_tasks(args):
    """Generates tasks for Pass 0 based on the initial log file.

    Builds one task entry per initially-failing model, seeding each with the
    full split-position schedule over the whole graph range.

    Args:
        args: parsed CLI namespace; reads ``tolerance`` (list of ints),
            ``log_file`` and ``max_subgraph_size``.

    Returns:
        ``(tasks_map, max_subgraph_size)`` where ``tasks_map`` maps model
        name -> task-info dict with "subgraph_path", "original_path" and a
        sorted "split_positions" list, and ``max_subgraph_size`` is the
        per-subgraph size cap for this pass.
    """
    print(f"[Init] Pass 0: Reading from log file: {args.log_file}")
    initial_failures = get_ranged_incorrect_models(args.tolerance, args.log_file)

    tasks_map = {}
    max_subgraph_size = args.max_subgraph_size

    if initial_failures:
        # The initial schedule only depends on the global range and the size
        # cap, so it is identical for every model: compute it once instead of
        # once per failing model, and skip the redundant re-sort pass (the
        # helper already returns a sorted, deduplicated list).
        initial_splits = calculate_split_positions_for_subgraph(
            [0, kMaxGraphSize], max_subgraph_size
        )

        for model_path in initial_failures:
            model_name = get_model_name_with_subgraph_tag(model_path)
            tasks_map[model_name] = {
                "subgraph_path": model_path,
                "original_path": model_path,
                # Copy so each task owns an independent, normalized list.
                "split_positions": sorted(initial_splits),
            }

    return tasks_map, max_subgraph_size
299329
300330
@@ -307,7 +337,6 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
307337 prev_incorrect_subgraphs = prev_config .get ("incorrect_models" , [])
308338 prev_tasks_map = prev_config .get ("tasks_map" , {})
309339
310- # Load previous max size as fallback
311340 prev_max_subgraph_size = prev_config .get ("max_subgraph_size" )
312341 max_subgraph_size = prev_max_subgraph_size // 2
313342
@@ -324,20 +353,30 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
324353 assert model_name in prev_tasks_map
325354 pre_task_for_model = prev_tasks_map [model_name ]
326355
327- # Reconstruct previous subgraph size to locate the failing segment
328356 prev_split_positions = pre_task_for_model .get ("split_positions" , [])
329- subgraph_size = reconstruct_subgraph_size (prev_split_positions )
357+ subgraph_ranges = reconstruct_subgraph_size (prev_split_positions )
358+
330359 assert subgraph_idx < len (
331- subgraph_size
360+ subgraph_ranges
332361 ), f"subgraph_idx { subgraph_idx } is out of bounds for { model_name } (previous split_positions: { prev_split_positions } )"
333362
363+ current_fail_range = subgraph_ranges [subgraph_idx ]
364+
365+ new_splits = calculate_split_positions_for_subgraph (
366+ current_fail_range , max_subgraph_size
367+ )
368+
334369 if model_name not in tasks_map :
335370 tasks_map [model_name ] = {
336371 "subgraph_path" : subgraph_path ,
337372 "original_path" : pre_task_for_model ["original_path" ],
338- "subgraph_size" : subgraph_size [subgraph_idx ],
339- "split_positions" : set (),
373+ "split_positions" : set (new_splits ),
340374 }
375+ else :
376+ tasks_map [model_name ]["split_positions" ].update (new_splits )
377+
378+ for task in tasks_map .values ():
379+ task ["split_positions" ] = sorted (list (task ["split_positions" ]))
341380
342381 return tasks_map , max_subgraph_size
343382
@@ -402,11 +441,17 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
402441 need_decompose = True
403442 shutil .rmtree (decomposed_samples_dir )
404443 os .makedirs (decomposed_samples_dir , exist_ok = True )
444+ max_subgraph_size = max (1 , max_subgraph_size // 2 )
405445 for model_name , task_info in tasks_map .items ():
406- task_info ["subgraph_size" ][1 ] = (
407- task_info ["subgraph_size" ][0 ] + max_subgraph_size
446+ splits = task_info ["split_positions" ]
447+ if not splits or len (splits ) < 2 :
448+ continue
449+ start_pos = splits [0 ]
450+ first_segment_end = splits [1 ]
451+ new_splits = calculate_split_positions_for_subgraph (
452+ [start_pos , first_segment_end ], max_subgraph_size
408453 )
409- max_subgraph_size = max ( 1 , max_subgraph_size // 2 )
454+ task_info [ "split_positions" ] = new_splits
410455 else :
411456 need_decompose = False
412457 print ()
@@ -474,8 +519,18 @@ def main(args):
474519 next_round_models = set ()
475520 if task_controller .task_scheduler ["post_analysis" ]:
476521 print ("\n --- Phase 3: Analysis ---" )
477- next_round_models = get_incorrect_models (args .tolerance , pass_log_path )
522+ analysis_tolerance = (
523+ args .tolerance [0 ] if isinstance (args .tolerance , list ) else args .tolerance
524+ )
525+ next_round_models = get_incorrect_models (analysis_tolerance , pass_log_path )
526+
478527 print (f"[Analysis] Found { len (next_round_models )} incorrect subgraphs.\n " )
528+ if len (next_round_models ) > 0 :
529+ print ("[DEBUG] List of detected incorrect models:" )
530+ for idx , model_path in enumerate (sorted (list (next_round_models ))):
531+ print (f" [{ idx } ] { model_path } " )
532+ else :
533+ print ("[DEBUG] No incorrect models detected." )
479534 print_summary_and_suggestion (next_round_models , max_subgraph_size )
480535
481536 # --- Step 5: Save States ---
@@ -497,7 +552,11 @@ def main(args):
497552 "--test-config" , type = str , required = True , help = "Base64 encoded test config"
498553 )
499554 parser .add_argument (
500- "--tolerance" , type = int , required = True , help = "Tolerance level range [-10, 5)"
555+ "--tolerance" ,
556+ type = int ,
557+ nargs = "+" ,
558+ required = True ,
559+ help = "Tolerance level range [-10, 5)" ,
501560 )
502561 parser .add_argument ("--max-subgraph-size" , type = int , default = 4096 )
503562 args = parser .parse_args ()
0 commit comments