@@ -19,24 +19,32 @@ def convert_b64_string_to_json(b64str):
 class TaskController:
     def __init__(self, args):
         self.root_output_dir = os.path.abspath(args.output_dir)
-        self.pass_id = self._determine_current_pass_id(self.root_output_dir)
         self.test_config = convert_b64_string_to_json(args.test_config)
         assert "test_module_name" in self.test_config
 
-        self._init_task_scheduler(self.test_config["test_module_name"])
+        test_module_name = self.test_config["test_module_name"]
+        max_pass_id = self._determine_max_pass_id(self.root_output_dir)
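+        # A target-device run re-enters the pass directory created by the
+        # preceding reference run; every other module starts a fresh pass.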
+        self.current_pass_id = (
+            max_pass_id if test_module_name == "test_target_device" else max_pass_id + 1
+        )
+        print(
+            f"test_module_name: {test_module_name}, current_pass_id: {self.current_pass_id}"
+        )
 
-    def _determine_current_pass_id(self, output_dir: str) -> int:
+        self._init_task_scheduler(test_module_name)
+
+    def _determine_max_pass_id(self, output_dir: str) -> int:
2937 """Scans the output directory to determine the next pass ID."""
         if not os.path.exists(output_dir):
-            return 0
+            return -1
         existing_passes = glob.glob(os.path.join(output_dir, "pass_*"))
         valid_ids = []
         for p in existing_passes:
             basename = os.path.basename(p)
             parts = basename.split("_")
             if len(parts) == 2 and parts[1].isdigit():
                 valid_ids.append(int(parts[1]))
-        return max(valid_ids) + 1 if valid_ids else 0
+        return max(valid_ids) if valid_ids else -1
 
     def _init_task_scheduler(self, test_module_name):
         assert test_module_name in [
@@ -177,7 +188,9 @@ def run_evaluation(
     test_module_name = test_config["test_module_name"]
     test_module_arguments = test_config[f"{test_module_name}_arguments"]
     test_module_arguments["model-path"] = work_dir
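+    # The target-device run now receives the same reference-dir as the
+    # reference run, so it can locate the saved reference outputs.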
-    if test_module_name == "test_reference_device":
+    if test_module_name in ["test_reference_device", "test_target_device"]:
         test_module_arguments["reference-dir"] = os.path.join(
             work_dir, "reference_device_outputs"
         )
@@ -204,7 +217,7 @@ def run_evaluation(
 def main(args):
     task_controller = TaskController(args)
     base_output_dir = task_controller.root_output_dir
-    current_pass_id = task_controller.pass_id
+    current_pass_id = task_controller.current_pass_id
 
     print("=" * 80)
     print(f" GraphNet Auto-Debugger | ROUND: PASS_{current_pass_id} ")
@@ -236,9 +249,10 @@ def main(args):
 
     # --- Step 2: Prepare Workspace ---
     pass_work_dir = os.path.join(base_output_dir, f"pass_{current_pass_id}")
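+    # Keep an existing pass directory: a target-device run must see the
+    # artifacts left behind by the reference run in the same pass.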
-    if os.path.exists(pass_work_dir):
-        shutil.rmtree(pass_work_dir)
-    os.makedirs(pass_work_dir, exist_ok=True)
+    if not os.path.exists(pass_work_dir):
+        os.makedirs(pass_work_dir, exist_ok=True)
 
     # --- Step 3: Decomposition ---
     need_decompose = (
@@ -248,28 +262,31 @@ def main(args):
     )
     if need_decompose:
         print("\n--- Phase 1: Decomposition ---", flush=True)
+        failed_extraction = []
         while need_decompose:
-            failed_extraction = []
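+            # All decomposed subgraphs now land in one shared samples directory
+            # (framework-specific name) instead of one directory per model.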
+            decomposed_samples_dir = os.path.join(
+                pass_work_dir, "samples" if args.framework == "torch" else "paddle_samples"
+            )
+            os.makedirs(decomposed_samples_dir, exist_ok=True)
 
             for idx, model_path in enumerate(target_models):
                 rectied_model_path = get_rectfied_model_path(model_path)
                 assert os.path.exists(
                     rectied_model_path
                 ), f"{rectied_model_path} does not exist."
 
-                model_name = get_model_name_with_subgraph_tag(rectied_model_path)
-                model_out_dir = os.path.join(pass_work_dir, model_name)
-                os.makedirs(model_out_dir, exist_ok=True)
-
+                os.makedirs(decomposed_samples_dir, exist_ok=True)
                 success = run_decomposer(
                     args.framework,
                     rectied_model_path,
-                    model_out_dir,
+                    decomposed_samples_dir,
                     current_max_size,
                 )
                 if not success:
                     failed_extraction.append(rectied_model_path)
-            num_decomposed_samples = count_samples(pass_work_dir)
+            num_decomposed_samples = count_samples(decomposed_samples_dir)
             print(
                 f"- number of graphs: {len(target_models)} -> {num_decomposed_samples}",
                 flush=True,
@@ -279,8 +296,10 @@ def main(args):
 
             if num_decomposed_samples == len(target_models):
                 need_decompose = True
-                shutil.rmtree(pass_work_dir)
-                os.makedirs(pass_work_dir, exist_ok=True)
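+                # Nothing was split further; clear only the samples directory
+                # (keeping the rest of the pass dir) and retry with half the cap.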
+                shutil.rmtree(decomposed_samples_dir)
+                os.makedirs(decomposed_samples_dir, exist_ok=True)
                 current_max_size = max(1, current_max_size // 2)
             else:
                 need_decompose = False
@@ -331,4 +350,5 @@ def main(args):
     )
     parser.add_argument("--max-subgraph-size", type=int, default=4096)
     args = parser.parse_args()
+    print(args)
     main(args)
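
A quick illustration of the new pass-ID selection (the directory state below is hypothetical, not from the patch):

# Suppose pass_0 .. pass_3 already exist, so _determine_max_pass_id returns 3.
max_pass_id = 3
for module in ("test_reference_device", "test_target_device"):
    pass_id = max_pass_id if module == "test_target_device" else max_pass_id + 1
    print(f"{module} -> pass_{pass_id}")
# test_reference_device -> pass_4  (starts a fresh pass)
# test_target_device    -> pass_3  (re-enters the latest pass)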