11import os
22import sys
3+ import re
34import json
45import base64
56import shutil
910from typing import List , Set , Dict , Any , Union
1011import graph_net
1112from graph_net .analysis_util import get_incorrect_models
12- from graph_net import path_utils , test_compiler_util
13+ from graph_net import path_utils
14+
15+ kMaxGraphSize = 4096
1316
1417
1518def convert_b64_string_to_json (b64str ):
@@ -22,16 +25,16 @@ def __init__(self, args):
2225 self .test_config = convert_b64_string_to_json (args .test_config )
2326 assert "test_module_name" in self .test_config
2427
25- test_module_name = self .test_config ["test_module_name" ]
28+ self . test_module_name = self .test_config ["test_module_name" ]
2629 max_pass_id = self ._determine_max_pass_id (self .root_output_dir )
2730 self .current_pass_id = (
28- max_pass_id if test_module_name == "test_target_device" else max_pass_id + 1
29- )
30- print (
31- f"test_module_name: { test_module_name } , current_pass_id: { self .current_pass_id } "
31+ max_pass_id
32+ if self .test_module_name == "test_target_device"
33+ else max_pass_id + 1
3234 )
3335
34- self ._init_task_scheduler (test_module_name )
36+ self ._init_task_scheduler (self .test_module_name )
37+ self ._print ()
3538
3639 def _determine_max_pass_id (self , output_dir : str ) -> int :
3740 """Scans the output directory to determine the next pass ID."""
@@ -71,6 +74,14 @@ def _init_task_scheduler(self, test_module_name):
7174 "post_analysis" : True ,
7275 }
7376
77+ def _print (self ):
78+ print (
79+ f"[TaskController] test_module_name: { self .test_module_name } , current_pass_id: { self .current_pass_id } " ,
80+ flush = True ,
81+ )
82+ print (f"[TaskController] task_scheduler: { self .task_scheduler } " , flush = True )
83+ print ()
84+
7485
7586def get_rectfied_model_path (model_path ):
7687 graphnet_root = path_utils .get_graphnet_root ()
def get_decompose_config_path(output_dir: str) -> str:
    """Return the path of the decompose config JSON inside *output_dir*."""
    config_filename = "decompose_config.json"
    return os.path.join(output_dir, config_filename)
91102
92103
93- def load_decompose_config (pass_id : int , output_dir : str ) -> Dict [str , Any ]:
def get_decompose_workspace_path(output_dir, pass_id):
    """Return the per-pass workspace directory ("pass_<id>") under *output_dir*."""
    workspace_name = f"pass_{pass_id}"
    return os.path.join(output_dir, workspace_name)
106+
107+
108+ def load_decompose_config (work_dir : str ) -> Dict [str , Any ]:
94109 """Loads the configuration file from the previous pass."""
95- prev_dir = os .path .join (output_dir , f"pass_{ pass_id - 1 } " )
96- config_path = get_decompose_config_path (prev_dir )
110+ config_path = get_decompose_config_path (work_dir )
97111
98112 if not os .path .exists (config_path ):
99113 raise FileNotFoundError (f"Missing configuration file: { config_path } " )
@@ -125,9 +139,9 @@ def save_decompose_config(
125139
126140
def get_model_name_with_subgraph_tag(model_path):
    """Derive a unique model name from a model directory path.

    If the last path component is a subgraph directory ("subgraph" or
    "subgraph_<N>"), the parent model directory name is prefixed so the
    result stays unique across a model's subgraphs ("<model>_subgraph_<N>");
    otherwise the last component is returned unchanged.

    NOTE(review): trailing separators are stripped with rstrip("/") but the
    split uses os.sep; these differ on Windows — confirm callers always pass
    POSIX-style paths.
    """
    fields = model_path.rstrip("/").split(os.sep)
    # Plain raw string: the pattern has no interpolation, so the original
    # "rf" prefix was an f-string-without-placeholders lint smell.
    pattern = r"^subgraph(_\d+)?$"
    if re.match(pattern, fields[-1]):
        return f"{fields[-2]}_{fields[-1]}"
    return fields[-1]
131145
132146
133147def run_decomposer (
@@ -256,6 +270,34 @@ def calculate_current_subgraph_size(
256270 )
257271
258272
def calculate_split_postions_for_subgraph(subgraph_size):
    """Compute split positions that bisect one failing subgraph interval.

    *subgraph_size* is a two-element [start, end] pair describing the failing
    segment.  An open-ended interval (end == inf) is clamped to the module
    constant kMaxGraphSize.  Returns [start, mid, end] when a strict interior
    midpoint exists, otherwise just [start, end].
    """
    assert isinstance(subgraph_size, (list, tuple)) and len(subgraph_size) == 2

    lo, hi = subgraph_size
    # Clamp the open-ended "last segment" marker to the global graph maximum.
    if hi == float("inf"):
        hi = kMaxGraphSize

    span = hi - lo
    half = span // 2
    # A span below 2 would yield a zero step; fall back to the full span so
    # the midpoint computation below stays well-defined.
    if half < 1:
        half = span

    mid = lo + half
    if lo < mid < hi:
        return [int(lo), int(mid), int(hi)]
    return [int(lo), int(hi)]
299+
300+
259301def main (args ):
260302 task_controller = TaskController (args )
261303 base_output_dir = task_controller .root_output_dir
@@ -267,7 +309,6 @@ def main(args):
267309
268310 tasks_map = {}
269311 active_models_map_for_save = {}
270- kMaxGraphSize = 4096
271312
272313 # Initialize using the argument passed from bash
273314 max_subgraph_size = args .max_subgraph_size
@@ -287,14 +328,15 @@ def main(args):
287328 "split_positions" : set (initial_splits ),
288329 }
289330 else :
290- prev_pass_dir = os .path .join (base_output_dir , f"pass_{ current_pass_id - 1 } " )
331+ prev_pass_dir = get_decompose_workspace_path (
332+ base_output_dir , current_pass_id - 1
333+ )
291334 print (
292- f"[Init] Resuming from Pass { current_pass_id - 1 } (Dir: { prev_pass_dir } )..."
335+ f"[Init] Resuming from Pass_ { current_pass_id - 1 } (Dir: { prev_pass_dir } )..."
293336 )
294337
295- prev_config = load_decompose_config (current_pass_id , base_output_dir )
296- prev_map = prev_config .get ("active_models_map" , {})
297-
338+ prev_config = load_decompose_config (prev_pass_dir )
339+ prev_active_models_map = prev_config .get ("active_models_map" , {})
298340 prev_used_splits = prev_config .get ("split_positions_map" , {})
299341 prev_incorrect_subgraphs = prev_config .get ("incorrect_models" , [])
300342
@@ -306,67 +348,37 @@ def main(args):
306348 print (f"[FINISHED] Debugging completed." )
307349 sys .exit (0 )
308350
309- print (f"[Analysis] Refining splits based on failures..." )
310-
311- for sub_path in prev_incorrect_subgraphs :
312- parts = sub_path .rstrip ("/" ).split ("/" )
313- if len (parts ) < 2 :
314- continue
315-
316- subgraph_dirname = parts [- 1 ]
317- model_name = parts [- 2 ]
318-
319- if model_name in prev_map :
320- active_models_map_for_save [model_name ] = prev_map [model_name ]
321- if model_name not in tasks_map :
322- tasks_map [model_name ] = {
323- "original_path" : prev_map [model_name ],
324- "split_positions" : set (),
325- }
326- else :
327- continue
328-
329- try :
330- sub_idx = int (subgraph_dirname .split ("_" )[- 1 ])
331- except ValueError :
332- continue
333-
334- # 1. Reconstruct previous subgraph size to locate the failing segment
335- old_split_position = sorted (prev_used_splits .get (model_name , []))
336- subgraph_size = reconstruct_subgraph_size (old_split_position )
337-
338- if sub_idx >= len (subgraph_size ):
339- print (
340- f"[WARN] Index { sub_idx } out of bounds for { model_name } (old split position: { old_split_position } )"
341- )
342- continue
351+ print (f"[Analysis] Refining splits based on previous incorrect models ..." )
343352
344- # 2. Get the specific failing subgraph size [Start, End]
345- fail_start , fail_end = subgraph_size [sub_idx ]
353+ for subgraph_path in prev_incorrect_subgraphs :
354+ print (f"- subgraph_path: { subgraph_path } " )
355+ model_name_with_subgraph_idx = subgraph_path .rstrip ("/" ).split (os .sep )[- 1 ]
356+ model_name = "_" .join (model_name_with_subgraph_idx .split ("_" )[:- 1 ])
357+ subgraph_idx = int (model_name_with_subgraph_idx .split ("_" )[- 1 ])
358+ print (f"- model_name: { model_name } , subgraph_idx: { subgraph_idx } " )
346359
347- # though intervals logic usually handles this via float('inf') replacement if used.
348- if fail_end == float ("inf" ):
349- fail_end = kMaxGraphSize
360+ assert model_name in prev_active_models_map
361+ active_models_map_for_save [model_name ] = prev_active_models_map [model_name ]
350362
351- # Dynamic step calculation
352- subgraph_size_len = fail_end - fail_start
353- new_step = subgraph_size_len // 2
363+ # Reconstruct previous subgraph size to locate the failing segment
364+ prev_split_positions = sorted (prev_used_splits .get (model_name , []))
365+ subgraph_size = reconstruct_subgraph_size (prev_split_positions )
366+ assert subgraph_idx < len (
367+ subgraph_size
368+ ), f"subgraph_idx { subgraph_idx } is out of bounds for { model_name } (previous split_positions: { prev_split_positions } )"
354369
355- if new_step < 1 :
356- new_step = subgraph_size_len
357-
358- # 3. Calculate Midpoint
359- mid_point = fail_start + new_step
360-
361- # 4. Add split positions
362- if mid_point > fail_start and mid_point < fail_end :
363- tasks_map [model_name ]["split_positions" ].update (
364- [int (fail_start ), int (mid_point ), int (fail_end )]
365- )
370+ split_postions = calculate_split_postions_for_subgraph (
371+ subgraph_size [subgraph_idx ]
372+ )
373+ if model_name not in tasks_map :
374+ tasks_map [model_name ] = {
375+ "subgraph_path" : subgraph_path ,
376+ "original_path" : prev_active_models_map [model_name ],
377+ "subgraph_size" : subgraph_size [subgraph_idx ],
378+ "split_positions" : split_postions ,
379+ }
366380 else :
367- tasks_map [model_name ]["split_positions" ].update (
368- [int (fail_start ), int (fail_end )]
369- )
381+ continue
370382
371383 # Recalculate based on current map to ensure log accuracy
372384 real_subgraph_size = calculate_current_subgraph_size (tasks_map , max_subgraph_size )
@@ -381,7 +393,7 @@ def main(args):
381393 sys .exit (0 )
382394
383395 # --- Step 2: Prepare Workspace ---
384- pass_work_dir = os . path . join (base_output_dir , f"pass_ { current_pass_id } " )
396+ pass_work_dir = get_decompose_workspace_path (base_output_dir , current_pass_id )
385397 if not os .path .exists (pass_work_dir ):
386398 os .makedirs (pass_work_dir , exist_ok = True )
387399
@@ -401,6 +413,7 @@ def main(args):
401413 pass_work_dir , "samples" if args .framework == "torch" else "paddle_samples"
402414 )
403415 os .makedirs (decomposed_samples_dir , exist_ok = True )
416+ print (f"decomposed_samples_dir: { decomposed_samples_dir } " )
404417
405418 for model_name , task_info in tasks_map .items ():
406419 original_path = task_info ["original_path" ]
@@ -409,6 +422,8 @@ def main(args):
409422 final_used_splits_map [model_name ] = split_positions
410423
411424 rectied_model_path = get_rectfied_model_path (original_path )
425+ print (f"original_path: { original_path } " )
426+ print (f"rectied_model_path: { rectied_model_path } " )
412427 assert os .path .exists (
413428 rectied_model_path
414429 ), f"{ rectied_model_path } does not exist."
0 commit comments