1212from graph_net import path_utils , test_compiler_util
1313
1414
def convert_b64_string_to_json(b64str):
    """Decode a base64-encoded UTF-8 string and parse the result as JSON."""
    decoded_text = base64.b64decode(b64str).decode("utf-8")
    return json.loads(decoded_text)
17+
18+
class TaskController:
    """Per-run state for the auto-debugger: resolved output directory,
    the current pass id, the decoded test config, and the phase flags
    (task scheduler) derived from the configured test module name.
    """

    # Phase flags per test module; a fresh copy is handed to each instance.
    _SCHEDULER_BY_MODULE = {
        "test_compiler": {
            "run_decomposer": True,
            "run_evaluation": True,
            "post_analysis": True,
        },
        "test_reference_device": {
            "run_decomposer": True,
            "run_evaluation": True,
            "post_analysis": False,
        },
        "test_target_device": {
            "run_decomposer": False,
            "run_evaluation": True,
            "post_analysis": True,
        },
    }

    def __init__(self, args):
        self.root_output_dir = os.path.abspath(args.output_dir)
        self.pass_id = self._determine_current_pass_id(self.root_output_dir)
        self.test_config = convert_b64_string_to_json(args.test_config)
        assert "test_module_name" in self.test_config
        self._init_task_scheduler(self.test_config["test_module_name"])

    def _determine_current_pass_id(self, output_dir: str) -> int:
        """Scans the output directory to determine the next pass ID."""
        if not os.path.exists(output_dir):
            return 0
        seen_ids = []
        for entry in glob.glob(os.path.join(output_dir, "pass_*")):
            pieces = os.path.basename(entry).split("_")
            # Only accept names of the exact shape "pass_<digits>".
            if len(pieces) == 2 and pieces[1].isdigit():
                seen_ids.append(int(pieces[1]))
        if not seen_ids:
            return 0
        return max(seen_ids) + 1

    def _init_task_scheduler(self, test_module_name):
        """Selects the phase flags matching *test_module_name*."""
        assert test_module_name in self._SCHEDULER_BY_MODULE
        self.task_scheduler = dict(self._SCHEDULER_BY_MODULE[test_module_name])
65+
66+
def get_rectfied_model_path(model_path):
    """Re-anchor *model_path* under the local GraphNet checkout root.

    Keeps the portion of the path after the last "GraphNet/" segment
    (the whole path if the segment is absent) and joins it onto the
    root reported by path_utils.get_graphnet_root().
    """
    relative_part = model_path.split("GraphNet/")[-1]
    return os.path.join(path_utils.get_graphnet_root(), relative_part)
@@ -25,20 +77,6 @@ def count_samples(samples_dir):
2577 return num_samples
2678
2779
28- def determine_current_pass_id (output_dir : str ) -> int :
29- """Scans the output directory to determine the next pass ID."""
30- if not os .path .exists (output_dir ):
31- return 0
32- existing_passes = glob .glob (os .path .join (output_dir , "pass_*" ))
33- valid_ids = []
34- for p in existing_passes :
35- basename = os .path .basename (p )
36- parts = basename .split ("_" )
37- if len (parts ) == 2 and parts [1 ].isdigit ():
38- valid_ids .append (int (parts [1 ]))
39- return max (valid_ids ) + 1 if valid_ids else 0
40-
41-
4280def load_prev_config (pass_id : int , output_dir : str ) -> Dict [str , Any ]:
4381 """Loads the configuration file from the previous pass."""
4482 prev_dir = os .path .join (output_dir , f"pass_{ pass_id - 1 } " )
@@ -78,32 +116,36 @@ def run_decomposer(
78116 model_path : str ,
79117 output_dir : str ,
80118 max_subgraph_size : int ,
81- decorator_config : Dict [str , Any ],
82119) -> bool :
83120 """Decomposes a single model."""
84- # 1. Calculate dynamic split positions
121+
85122 upper_bound = 4096
86123 split_positions = list (
87124 range (max_subgraph_size , upper_bound + max_subgraph_size , max_subgraph_size )
88125 )
89126
90- # 2. Deep copy the template
91- final_decorator_config = json .loads (json .dumps (decorator_config ))
92-
93- # 3. Locate the nested dictionary to inject values
94- decorator_cfg = final_decorator_config ["decorator_config" ]
95-
127+ graphnet_root = path_utils .get_graphnet_root ()
96128 model_name = get_model_name_with_subgraph_tag (model_path )
97- decorator_cfg ["name" ] = model_name
98-
99- custom_cfg = decorator_cfg .get ("custom_extractor_config" , {})
100- custom_cfg ["output_dir" ] = output_dir
101- custom_cfg ["split_positions" ] = split_positions
102-
103- # 4. Encode and Run
104- decorator_config_json = json .dumps (final_decorator_config )
105- decorator_config_b64 = base64 .b64encode (decorator_config_json .encode ()).decode ()
129+ decorator_config = {
130+ "decorator_path" : f"{ graphnet_root } /graph_net/{ framework } /extractor.py" ,
131+ "decorator_config" : {
132+ "name" : model_name ,
133+ "custom_extractor_path" : f"{ graphnet_root } /graph_net/{ framework } /naive_graph_decomposer.py" ,
134+ "custom_extractor_config" : {
135+ "output_dir" : output_dir ,
136+ "split_positions" : split_positions ,
137+ "group_head_and_tail" : True ,
138+ "chain_style" : False ,
139+ },
140+ },
141+ }
142+ decorator_config_b64 = base64 .b64encode (
143+ json .dumps (decorator_config ).encode ()
144+ ).decode ()
106145
146+ print (
147+ f"- [Decomposing] { model_name } (max_subgraph_size={ max_subgraph_size } , split_positions={ split_positions } )"
148+ )
107149 cmd = [
108150 sys .executable ,
109151 "-m" ,
@@ -113,74 +155,64 @@ def run_decomposer(
113155 "--decorator-config" ,
114156 decorator_config_b64 ,
115157 ]
116-
117- print (
118- f"- [Decomposing] { model_name } (max_subgraph_size={ max_subgraph_size } , split_positions={ split_positions } )"
119- )
120-
121158 result = subprocess .run (
122159 cmd , stdout = subprocess .PIPE , stderr = subprocess .PIPE , text = True
123160 )
124161 if result .returncode != 0 :
125- print (f"[ERROR] Decomposition failed for { model_path } \n { result .stderr } " )
162+ print (
163+ f"[ERROR] Decomposition failed for { model_path } \n { result .stderr } " ,
164+ flush = True ,
165+ )
126166 return False
127167 # print(result.stdout)
128168 return True
129169
130170
131- def run_evaluation (test_cmd_b64 : str , work_dir : str , log_path : str ) -> int :
171+ def run_evaluation (
172+ framework : str , test_cmd_b64 : str , work_dir : str , log_path : str
173+ ) -> int :
132174 """Executes the test command on the batch directory."""
133- json_str = base64 .b64decode (test_cmd_b64 ).decode ("utf-8" )
134- cmd_config = json .loads (json_str )
135175
136- # Check if the config follows the new structure: {"module_name": "...", "arguments": {...}}
137- if "module_name" in cmd_config and "arguments" in cmd_config :
138- target_module = cmd_config ["module_name" ]
139- args_dict = cmd_config ["arguments" ]
140- else :
141- # Fallback for old format (flat dictionary), assuming default compiler test
142- target_module = "graph_net.torch.test_compiler"
143- args_dict = cmd_config
144-
145- cmd = [sys .executable , "-m" , target_module ]
146-
147- for key , value in args_dict .items ():
148- if key == "model_path" :
149- continue
150- cmd .append (f"--{ key } " )
151- cmd .append (str (value ))
176+ test_config = convert_b64_string_to_json (test_cmd_b64 )
177+ test_module_name = test_config ["test_module_name" ]
178+ test_module_arguments = test_config [f"{ test_module_name } _arguments" ]
179+ test_module_arguments ["model-path" ] = work_dir
180+ if test_module_name == "test_reference_device" :
181+ test_module_arguments ["reference-dir" ] = os .path .join (
182+ work_dir , "reference_device_outputs"
183+ )
152184
153- cmd .append ("--model-path" )
154- cmd .append (work_dir )
185+ cmd = [sys .executable , "-m" , f"graph_net.{ framework } .{ test_module_name } " ] + [
186+ item
187+ for key , value in test_module_arguments .items ()
188+ for item in (f"--{ key } " , str (value ))
189+ ]
155190
156191 print (f" [Batch Testing] Logging to: { log_path } " )
157192 print (f" [Command] { ' ' .join (cmd )} " )
158193
159194 os .makedirs (os .path .dirname (log_path ), exist_ok = True )
160195 with open (log_path , "w" ) as f :
161196 proc = subprocess .run (cmd , stdout = f , stderr = subprocess .STDOUT , text = True )
162- return proc .returncode
197+ if proc .returncode != 0 :
198+ with open (log_path , "r" ) as f :
199+ content = f .read ()
200+ print (f"[ERROR] test failed for { work_dir } \n { content } " , flush = True )
201+ sys .exit (proc .returncode )
163202
164203
165- # ==========================================================
166- # Main Execution Flow
167- # ==========================================================
168204def main (args ):
169- base_output_dir = os .path .abspath (args .output_dir )
170- current_pass_id = determine_current_pass_id (base_output_dir )
171-
172- # Parse the Decorator Configuration Template
173- tpl_str = base64 .b64decode (args .decorator_config ).decode ("utf-8" )
174- decorator_template = json .loads (tpl_str )
205+ task_controller = TaskController (args )
206+ base_output_dir = task_controller .root_output_dir
207+ current_pass_id = task_controller .pass_id
175208
176209 print ("=" * 80 )
177210 print (f" GraphNet Auto-Debugger | ROUND: PASS_{ current_pass_id } " )
178211 print ("=" * 80 )
179212
180213 # --- Step 1: Initialize State ---
181214 target_models = []
182- current_max_size = 2048
183-
215+ current_max_size = args .max_subgraph_size
184216 if current_pass_id == 0 :
185217 print (f"[Init] Pass 0: Reading from log file: { args .log_file } " )
186218 current_max_size = args .max_subgraph_size
@@ -209,8 +241,13 @@ def main(args):
209241 os .makedirs (pass_work_dir , exist_ok = True )
210242
211243 # --- Step 3: Decomposition ---
212- print ("\n --- Phase 1: Decomposition ---" , flush = True )
213- need_decompose = True if len (target_models ) > 0 else False
244+ need_decompose = (
245+ True
246+ if task_controller .task_scheduler ["run_decomposer" ] and len (target_models ) > 0
247+ else False
248+ )
249+ if need_decompose :
250+ print ("\n --- Phase 1: Decomposition ---" , flush = True )
214251 while need_decompose :
215252 failed_extraction = []
216253
@@ -229,7 +266,6 @@ def main(args):
229266 rectied_model_path ,
230267 model_out_dir ,
231268 current_max_size ,
232- decorator_template ,
233269 )
234270 if not success :
235271 failed_extraction .append (rectied_model_path )
@@ -250,18 +286,21 @@ def main(args):
250286 need_decompose = False
251287
252288 # --- Step 4: Testing ---
253- print ("\n --- Phase 2: Batch Testing ---" )
254- pass_log_path = os .path .join (pass_work_dir , "batch_test_result.log" )
255- run_evaluation (args .test_config , pass_work_dir , pass_log_path )
289+ if task_controller .task_scheduler ["run_evaluation" ]:
290+ print ("\n --- Phase 2: Batch Testing ---" )
291+ pass_log_path = os .path .join (pass_work_dir , "batch_test_result.log" )
292+ run_evaluation (args .framework , args .test_config , pass_work_dir , pass_log_path )
256293
257294 # --- Step 5: Analysis ---
258- print ("\n --- Phase 3: Analysis ---" )
259295 next_round_models = set ()
260- try :
261- next_round_models = set (get_incorrect_models (args .tolerance , pass_log_path ))
262- print (f" [Result] Found { len (next_round_models )} incorrect subgraphs." )
263- except Exception as e :
264- print (f" [ERROR] Log analysis failed: { e } " )
296+ if task_controller .task_scheduler ["post_analysis" ]:
297+ print ("\n --- Phase 3: Analysis ---" )
298+ next_round_models = set ()
299+ try :
300+ next_round_models = set (get_incorrect_models (args .tolerance , pass_log_path ))
301+ print (f" [Result] Found { len (next_round_models )} incorrect subgraphs." )
302+ except Exception as e :
303+ print (f" [ERROR] Log analysis failed: { e } " )
265304
266305 # --- Step 6: Save State ---
267306 save_current_config (
@@ -287,16 +326,9 @@ def main(args):
287326 parser .add_argument (
288327 "--test-config" , type = str , required = True , help = "Base64 encoded test config"
289328 )
290- parser .add_argument (
291- "--decorator-config" ,
292- type = str ,
293- required = True ,
294- help = "Base64 encoded decorator config template" ,
295- )
296329 parser .add_argument (
297330 "--tolerance" , type = int , required = True , help = "Tolerance level range [-10, 5)"
298331 )
299332 parser .add_argument ("--max-subgraph-size" , type = int , default = 4096 )
300-
301333 args = parser .parse_args ()
302334 main (args )
0 commit comments