@@ -1,5 +1,6 @@
 import os
 import sys
+import re
 import json
 import base64
 import shutil
@@ -9,7 +10,7 @@
 from typing import List, Set, Dict, Any, Union
 import graph_net
 from graph_net.analysis_util import get_incorrect_models
-from graph_net import path_utils, test_compiler_util
+from graph_net import path_utils


 def convert_b64_string_to_json(b64str):
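Note: the body of convert_b64_string_to_json is collapsed out of this diff. Judging from the base64/json imports and the call site in __init__ below, it presumably decodes a base64-encoded CLI payload into a dict. A minimal sketch under that assumption, not the verbatim implementation:

import base64
import json

def convert_b64_string_to_json(b64str):
    # Assumed: decode the CLI-safe base64 payload, then parse it as JSON.
    return json.loads(base64.b64decode(b64str).decode("utf-8"))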
@@ -22,16 +23,16 @@ def __init__(self, args):
         self.test_config = convert_b64_string_to_json(args.test_config)
         assert "test_module_name" in self.test_config

-        test_module_name = self.test_config["test_module_name"]
+        self.test_module_name = self.test_config["test_module_name"]
         max_pass_id = self._determine_max_pass_id(self.root_output_dir)
         self.current_pass_id = (
-            max_pass_id if test_module_name == "test_target_device" else max_pass_id + 1
-        )
-        print(
-            f"test_module_name: {test_module_name}, current_pass_id: {self.current_pass_id}"
+            max_pass_id
+            if self.test_module_name == "test_target_device"
+            else max_pass_id + 1
         )

-        self._init_task_scheduler(test_module_name)
+        self._init_task_scheduler(self.test_module_name)
+        self._print()

     def _determine_max_pass_id(self, output_dir: str) -> int:
         """Scans the output directory to determine the next pass ID."""
@@ -71,6 +72,14 @@ def _init_task_scheduler(self, test_module_name):
7172 "post_analysis" : True ,
7273 }
7374
75+ def _print (self ):
76+ print (
77+ f"[TaskController] test_module_name: { self .test_module_name } , current_pass_id: { self .current_pass_id } " ,
78+ flush = True ,
79+ )
80+ print (f"[TaskController] task_scheduler: { self .task_scheduler } " , flush = True )
81+ print ()
82+
7483
7584def get_rectfied_model_path (model_path ):
7685 graphnet_root = path_utils .get_graphnet_root ()
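With the new _print hook, the controller logs its resolved state at startup. The output shape follows the f-strings above; the values here are illustrative:

[TaskController] test_module_name: test_target_device, current_pass_id: 3
[TaskController] task_scheduler: {'post_analysis': True, ...}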
@@ -90,10 +99,13 @@ def get_decompose_config_path(output_dir: str) -> str:
     return os.path.join(output_dir, "decompose_config.json")


-def load_decompose_config(pass_id: int, output_dir: str) -> Dict[str, Any]:
+def get_decompose_workspace_path(output_dir, pass_id):
+    return os.path.join(output_dir, f"pass_{pass_id}")
+
+
+def load_decompose_config(work_dir: str) -> Dict[str, Any]:
     """Loads the configuration file from the previous pass."""
-    prev_dir = os.path.join(output_dir, f"pass_{pass_id - 1}")
-    config_path = get_decompose_config_path(prev_dir)
+    config_path = get_decompose_config_path(work_dir)

     if not os.path.exists(config_path):
         raise FileNotFoundError(f"Missing configuration file: {config_path}")
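The refactor separates workspace resolution from config loading, so the same path helper now serves both the reader here and the writer later in main. Mirroring the call sites below:

prev_pass_dir = get_decompose_workspace_path(base_output_dir, current_pass_id - 1)
prev_config = load_decompose_config(prev_pass_dir)  # reads pass_<n-1>/decompose_config.json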
@@ -125,9 +137,9 @@ def save_decompose_config(


 def get_model_name_with_subgraph_tag(model_path):
-    model_name = test_compiler_util.get_model_name(model_path)
-    subgraph_tag = test_compiler_util.get_subgraph_tag(model_path)
-    return f"{model_name}_{subgraph_tag}" if subgraph_tag else model_name
+    fields = model_path.rstrip("/").split(os.sep)
+    pattern = rf"^subgraph(_\d+)?$"
+    return f"{fields[-2]}_{fields[-1]}" if re.match(pattern, fields[-1]) else fields[-1]


 def run_decomposer(
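The rewritten helper derives the tag from the directory layout instead of test_compiler_util. With hypothetical paths (and assuming a POSIX os.sep, since the path is split on it):

get_model_name_with_subgraph_tag("samples/resnet50/subgraph_3")  # -> "resnet50_subgraph_3"
get_model_name_with_subgraph_tag("samples/resnet50/subgraph")    # -> "resnet50_subgraph"
get_model_name_with_subgraph_tag("samples/resnet50")             # -> "resnet50"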
@@ -287,14 +299,15 @@ def main(args):
287299 "split_positions" : set (initial_splits ),
288300 }
289301 else :
290- prev_pass_dir = os .path .join (base_output_dir , f"pass_{ current_pass_id - 1 } " )
302+ prev_pass_dir = get_decompose_workspace_path (
303+ base_output_dir , current_pass_id - 1
304+ )
291305 print (
292- f"[Init] Resuming from Pass { current_pass_id - 1 } (Dir: { prev_pass_dir } )..."
306+ f"[Init] Resuming from Pass_ { current_pass_id - 1 } (Dir: { prev_pass_dir } )..."
293307 )
294308
295- prev_config = load_decompose_config (current_pass_id , base_output_dir )
296- prev_map = prev_config .get ("active_models_map" , {})
297-
309+ prev_config = load_decompose_config (prev_pass_dir )
310+ prev_tasks_map = prev_config .get ("active_models_map" , {})
298311 prev_used_splits = prev_config .get ("split_positions_map" , {})
299312 prev_incorrect_subgraphs = prev_config .get ("incorrect_models" , [])
300313
@@ -308,41 +321,35 @@ def main(args):

         print(f"[Analysis] Refining splits based on failures...")

-        for sub_path in prev_incorrect_subgraphs:
-            parts = sub_path.rstrip("/").split("/")
-            if len(parts) < 2:
-                continue
+        for subgraph_path in prev_incorrect_subgraphs:
+            print(f"- subgraph_path: {subgraph_path}")
+            model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
+            model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
+            subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
+            print(f"- model_name: {model_name}, subgraph_idx: {subgraph_idx}")

-            subgraph_dirname = parts[-1]
-            model_name = parts[-2]
-
-            if model_name in prev_map:
-                active_models_map_for_save[model_name] = prev_map[model_name]
+            if model_name in prev_tasks_map:
+                active_models_map_for_save[model_name] = prev_tasks_map[model_name]
                 if model_name not in tasks_map:
                     tasks_map[model_name] = {
-                        "original_path": prev_map[model_name],
+                        "original_path": prev_tasks_map[model_name],
                         "split_positions": set(),
                     }
             else:
                 continue

-            try:
-                sub_idx = int(subgraph_dirname.split("_")[-1])
-            except ValueError:
-                continue
-
             # 1. Reconstruct previous subgraph size to locate the failing segment
-            old_split_position = sorted(prev_used_splits.get(model_name, []))
-            subgraph_size = reconstruct_subgraph_size(old_split_position)
+            prev_split_positions = sorted(prev_used_splits.get(model_name, []))
+            subgraph_size = reconstruct_subgraph_size(prev_split_positions)

-            if sub_idx >= len(subgraph_size):
+            if subgraph_idx >= len(subgraph_size):
                 print(
-                    f"[WARN] Index {sub_idx} out of bounds for {model_name} (old split position: {old_split_position})"
+                    f"[WARN] Index {subgraph_idx} out of bounds for {model_name} (previous split_positions: {prev_split_positions})"
                 )
                 continue

             # 2. Get the specific failing subgraph size [Start, End]
-            fail_start, fail_end = subgraph_size[sub_idx]
+            fail_start, fail_end = subgraph_size[subgraph_idx]

             # though intervals logic usually handles this via float('inf') replacement if used.
             if fail_end == float("inf"):
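reconstruct_subgraph_size is defined elsewhere in this file; from how it is consumed here (indexed [start, end] pairs whose final end can be float("inf")), it plausibly expands sorted split positions into consecutive intervals. A sketch under that assumption, not the actual definition:

def reconstruct_subgraph_size(split_positions):
    # Assumed: [10, 25] -> [(0, 10), (10, 25), (25, inf)],
    # one interval per emitted subgraph, the last one open-ended.
    bounds = [0] + list(split_positions) + [float("inf")]
    return list(zip(bounds[:-1], bounds[1:]))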
@@ -381,7 +388,7 @@ def main(args):
         sys.exit(0)

     # --- Step 2: Prepare Workspace ---
-    pass_work_dir = os.path.join(base_output_dir, f"pass_{current_pass_id}")
+    pass_work_dir = get_decompose_workspace_path(base_output_dir, current_pass_id)
     if not os.path.exists(pass_work_dir):
         os.makedirs(pass_work_dir, exist_ok=True)

@@ -401,6 +408,7 @@ def main(args):
         pass_work_dir, "samples" if args.framework == "torch" else "paddle_samples"
     )
     os.makedirs(decomposed_samples_dir, exist_ok=True)
+    print(f"decomposed_samples_dir: {decomposed_samples_dir}")

     for model_name, task_info in tasks_map.items():
         original_path = task_info["original_path"]
@@ -409,6 +417,8 @@ def main(args):
         final_used_splits_map[model_name] = split_positions

         rectied_model_path = get_rectfied_model_path(original_path)
+        print(f"original_path: {original_path}")
+        print(f"rectied_model_path: {rectied_model_path}")
         assert os.path.exists(
             rectied_model_path
         ), f"{rectied_model_path} does not exist."