@@ -244,6 +244,12 @@ def run_decomposer_for_multi_models(
244244 original_path = task_info ["original_path" ]
245245
246246 split_positions = sorted (list (task_info ["split_positions" ]))
247+
248+ method = "fixed-start"
249+ if method == "fixed-start" :
250+ assert len (split_positions ) >= 3 , f"{ split_positions = } "
251+ split_positions = [0 , split_positions [1 ]]
252+
247253 rectified_model_path = get_rectfied_model_path (original_path )
248254 assert os .path .exists (
249255 rectified_model_path
@@ -293,27 +299,23 @@ def run_evaluation(
293299 ), f"[ERROR] test failed for { work_dir } , please check the log."
294300
295301
296- def reconstruct_subgraph_size (split_positions : List [int ]) -> List [list ]:
297- """Reconstructs subgraph size based on sorted split positions."""
298- deduplicated_splits = sorted (set (split_positions ))
299-
300- subgraph_size = [
301- deduplicated_splits [i : i + 2 ] for i in range (len (deduplicated_splits ) - 1 )
302- ]
303- return subgraph_size
def reconstruct_split_positions_for_subgraphs(
    split_positions, subgraph_idxs, max_subgraph_size
):
    """Re-split the selected subgraphs into steps of ``max_subgraph_size``.

    Subgraph ``i`` is the range between the ``i``-th and ``(i + 1)``-th entry
    of the sorted, de-duplicated ``split_positions``.

    Args:
        split_positions: Split positions of the previous decomposition.
        subgraph_idxs: A single subgraph index or an iterable of indices.
        max_subgraph_size: Step between consecutive new split positions;
            must be >= 1.

    Returns:
        Sorted list of unique split positions covering every selected
        subgraph. Empty when ``subgraph_idxs`` is empty.

    Raises:
        AssertionError: If ``max_subgraph_size`` is not positive or an index
            does not address an existing subgraph.
    """
    assert (
        max_subgraph_size >= 1
    ), f"max_subgraph_size must be positive, got {max_subgraph_size}."

    if isinstance(subgraph_idxs, int):
        subgraph_idxs = [subgraph_idxs]

    # Index into sorted, unique boundaries: the decomposer sorts the positions
    # before splitting, so subgraph indices refer to that order. This is a
    # no-op for the lists this function itself returns.
    boundaries = sorted(set(split_positions))

    new_split_positions = set()
    for subgraph_idx in subgraph_idxs:
        # The lower bound is checked too: a negative index would otherwise
        # pass the check and fail later with an unrelated unpacking error.
        assert (
            0 <= subgraph_idx < len(boundaries) - 1
        ), f"subgraph_idx { subgraph_idx } is out of bounds of split_positions: { split_positions } ."

        start_pos = boundaries[subgraph_idx]
        end_pos = boundaries[subgraph_idx + 1]
        # The stop value is padded so the last emitted position is the first
        # grid point >= end_pos (it may overshoot end_pos).
        new_split_positions.update(
            range(start_pos, end_pos + max_subgraph_size - 1, max_subgraph_size)
        )

    return sorted(new_split_positions)
317319
318320
319321def generate_initial_tasks (args ):
@@ -322,19 +324,16 @@ def generate_initial_tasks(args):
322324 initial_failures = get_ranged_incorrect_models (args .tolerance , args .log_file )
323325
324326 tasks_map = {}
325- max_subgraph_size = args .max_subgraph_size
327+ max_subgraph_size = min ( args .max_subgraph_size , kMaxGraphSize // 2 )
326328
329+ initial_split_positions = reconstruct_split_positions_for_subgraphs (
330+ [0 , kMaxGraphSize ], 0 , max_subgraph_size
331+ )
327332 for model_path in initial_failures :
328333 model_name = get_model_name_with_subgraph_tag (model_path )
329-
330- initial_range = [0 , kMaxGraphSize ]
331- initial_splits = calculate_split_positions_for_subgraph (
332- initial_range , max_subgraph_size
333- )
334-
335334 tasks_map [model_name ] = {
336335 "original_path" : model_path ,
337- "split_positions" : initial_splits ,
336+ "split_positions" : initial_split_positions ,
338337 }
339338
340339 running_states = {
@@ -354,7 +353,29 @@ def extract_model_name_and_subgraph_idx(subgraph_path):
354353 return model_name , subgraph_idx
355354
356355
357- def generate_successor_tasks (base_output_dir , current_pass_id ):
def collect_incorrect_subgraph_idxs(args, model_names, incorrect_models):
    """Group the indices of incorrect subgraphs by model name.

    Args:
        args: Parsed arguments; only ``args.method`` is read.
        model_names: Names of the models to consider when
            ``args.method == "fixed-start"``.
        incorrect_models: Paths of the subgraphs reported as incorrect; each
            is resolved with ``extract_model_name_and_subgraph_idx``.

    Returns:
        Dict mapping model name to the list of its incorrect subgraph
        indices, in the order of the sorted paths.

    Raises:
        AssertionError: In "fixed-start" mode, if a model in ``model_names``
            has incorrect subgraphs other than exactly ``[0]``.
    """
    model_name2subgraph_idxs = {}
    for subgraph_path in sorted(incorrect_models):
        model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
        print(f"{ subgraph_path = } ")
        model_name2subgraph_idxs.setdefault(model_name, []).append(subgraph_idx)

    if args.method == "fixed-start":
        for model_name in model_names:
            if model_name not in model_name2subgraph_idxs:
                # No incorrect subgraph was reported for this model, so
                # subgraph 1 is selected instead.
                # NOTE(review): presumably subgraph 0 passed and the failure
                # lies after it -- confirm.
                model_name2subgraph_idxs[model_name] = [1]
            else:
                subgraph_idxs = model_name2subgraph_idxs[model_name]
                # Compare against the list [0]; comparing the list with the
                # int 0 is always False and made this assert fail every time.
                assert subgraph_idxs == [0], (
                    f"fixed-start expects only subgraph 0 to be incorrect for "
                    f"{model_name}, got {subgraph_idxs}."
                )
    return model_name2subgraph_idxs
376+
377+
378+ def generate_successor_tasks (args , base_output_dir , current_pass_id ):
358379 """Generates tasks for Pass > 0 based on previous pass results."""
359380 prev_pass_dir = get_decompose_workspace_path (base_output_dir , current_pass_id - 1 )
360381 print (f"[Init] Resuming from Pass_{ current_pass_id - 1 } (Dir: { prev_pass_dir } )..." )
@@ -367,34 +388,24 @@ def generate_successor_tasks(base_output_dir, current_pass_id):
367388 tasks_map = {}
368389 prev_tasks_map = prev_config .tasks_map
369390
370- for subgraph_path in sorted (prev_config .incorrect_models ):
371- model_name , subgraph_idx = extract_model_name_and_subgraph_idx (subgraph_path )
391+ model_name2subgraph_idxs = collect_incorrect_subgraph_idxs (
392+ args , list (prev_tasks_map .keys ()), prev_config .incorrect_models
393+ )
372394
395+ for model_name , subgraph_idxs in model_name2subgraph_idxs .items ():
373396 assert model_name in prev_tasks_map
374397 pre_task_for_model = prev_tasks_map [model_name ]
375398
376399 prev_split_positions = pre_task_for_model .get ("split_positions" , [])
377- subgraph_ranges = reconstruct_subgraph_size (prev_split_positions )
378-
379- assert subgraph_idx < len (
380- subgraph_ranges
381- ), f"subgraph_idx { subgraph_idx } is out of bounds for { model_name } (previous split_positions: { prev_split_positions } )"
382-
383- split_positions = calculate_split_positions_for_subgraph (
384- subgraph_ranges [subgraph_idx ], max_subgraph_size
400+ split_positions = reconstruct_split_positions_for_subgraphs (
401+ prev_split_positions , subgraph_idxs , max_subgraph_size
385402 )
386- if model_name not in tasks_map :
387- tasks_map [model_name ] = {
388- "original_path" : pre_task_for_model ["original_path" ],
389- "split_positions" : list (sorted (split_positions )),
390- }
391- else :
392- merged_split_positions = (
393- tasks_map [model_name ]["split_positions" ] + split_positions
394- )
395- tasks_map [model_name ]["split_positions" ] = list (
396- sorted (set (merged_split_positions ))
397- )
403+
404+ tasks_map [model_name ] = {
405+ "original_path" : pre_task_for_model ["original_path" ],
406+ "split_positions" : split_positions ,
407+ }
408+ print (f"{ tasks_map = } " )
398409
399410 return tasks_map , max_subgraph_size , prev_config .running_states
400411
@@ -404,7 +415,7 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
404415 tasks_map , max_subgraph_size , running_states = generate_initial_tasks (args )
405416 else :
406417 tasks_map , max_subgraph_size , running_states = generate_successor_tasks (
407- base_output_dir , current_pass_id
418+ args , base_output_dir , current_pass_id
408419 )
409420
410421 print (f"[Init] initial max_subgraph_size: { max_subgraph_size } " )
@@ -431,6 +442,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
431442
432443 failed_decomposition = []
433444 need_decompose = True if len (tasks_map ) > 0 else False
445+ method = "fixed-start"
434446
435447 while need_decompose :
436448 decomposed_samples_dir = os .path .join (
@@ -455,6 +467,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
455467 not failed_decomposition
456468 and num_decomposed_samples == len (tasks_map )
457469 and max_subgraph_size > 1
470+ and method != "fixed-start"
458471 ):
459472 need_decompose = True
460473 shutil .rmtree (decomposed_samples_dir )
@@ -464,8 +477,8 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
464477 split_positions = task_info ["split_positions" ]
465478 if not split_positions or len (split_positions ) < 2 :
466479 continue
467- new_split_positions = calculate_split_positions_for_subgraph (
468- split_positions [ 0 : 2 ] , max_subgraph_size
480+ new_split_positions = reconstruct_split_positions_for_subgraphs (
481+ split_positions , 0 , max_subgraph_size
469482 )
470483 task_info ["split_positions" ] = new_split_positions
471484 else :
@@ -579,6 +592,7 @@ def main(args):
579592 parser .add_argument (
580593 "--test-config" , type = str , required = True , help = "Base64 encoded test config"
581594 )
595+ parser .add_argument ("--method" , type = str , required = True )
582596 parser .add_argument (
583597 "--tolerance" ,
584598 type = int ,
0 commit comments