@@ -18,6 +18,30 @@ def convert_b64_string_to_json(b64str):
1818 return json .loads (base64 .b64decode (b64str ).decode ("utf-8" ))
1919
2020
21+ def get_filtered_incorrect_models (tolerance_args : List [int ], log_path : str ) -> set :
22+ if not os .path .exists (log_path ):
23+ return set ()
24+
25+ t_start = tolerance_args [0 ]
26+ models_start = set (get_incorrect_models (t_start , log_path ))
27+
28+ if len (tolerance_args ) == 1 :
29+ return models_start
30+
31+ t_end = tolerance_args [1 ]
32+ models_end = set (get_incorrect_models (t_end , log_path ))
33+
34+ print (f"[Filter] Tolerance Range: { t_start } -> { t_end } " )
35+ print (
36+ f"[Filter] Fail({ t_start } ): { len (models_start )} , Fail({ t_end } ): { len (models_end )} "
37+ )
38+
39+ diff_set = models_start - models_end
40+ print (f"[Filter] Result (Difference): { len (diff_set )} " )
41+
42+ return diff_set
43+
44+
2145class TaskController :
2246 def __init__ (self , args ):
2347 self .root_output_dir = os .path .abspath (args .output_dir )
@@ -290,7 +314,7 @@ def calculate_split_positions_for_subgraph(subgraph_size, max_subgraph_size):
290314def generate_initial_tasks (args ):
291315 """Generates tasks for Pass 0 based on the initial log file."""
292316 print (f"[Init] Pass 0: Reading from log file: { args .log_file } " )
293- initial_failures = get_incorrect_models (args .tolerance , args .log_file )
317+ initial_failures = get_filtered_incorrect_models (args .tolerance , args .log_file )
294318
295319 tasks_map = {}
296320 for model_path in initial_failures :
@@ -487,7 +511,7 @@ def main(args):
487511 next_round_models = set ()
488512 if task_controller .task_scheduler ["post_analysis" ]:
489513 print ("\n --- Phase 3: Analysis ---" )
490- next_round_models = get_incorrect_models (args .tolerance , pass_log_path )
514+ next_round_models = get_filtered_incorrect_models (args .tolerance , pass_log_path )
491515 print (f"[Analysis] Found { len (next_round_models )} incorrect subgraphs.\n " )
492516 if len (next_round_models ) > 0 :
493517 print ("[DEBUG] List of detected incorrect models:" )
@@ -516,7 +540,11 @@ def main(args):
516540 "--test-config" , type = str , required = True , help = "Base64 encoded test config"
517541 )
518542 parser .add_argument (
519- "--tolerance" , type = int , required = True , help = "Tolerance level range [-10, 5)"
543+ "--tolerance" ,
544+ type = int ,
545+ nargs = "+" ,
546+ required = True ,
547+ help = "Tolerance level range [-10, 5)" ,
520548 )
521549 parser .add_argument ("--max-subgraph-size" , type = int , default = 4096 )
522550 args = parser .parse_args ()
0 commit comments