@@ -227,10 +227,7 @@ def run_decomposer_for_multi_models(
227227 for model_name , task_info in tasks_map .items ():
228228 original_path = task_info ["original_path" ]
229229
230- split_positions = task_info ["split_positions" ]
231- if isinstance (split_positions , set ):
232- split_positions = sorted (list (split_positions ))
233-
230+ split_positions = sorted (list (task_info ["split_positions" ]))
234231 rectified_model_path = get_rectfied_model_path (original_path )
235232 assert os .path .exists (
236233 rectified_model_path
@@ -298,8 +295,9 @@ def calculate_split_positions_for_subgraph(subgraph_size, max_subgraph_size):
298295 end_pos = kMaxGraphSize if end_pos == float ("inf" ) else end_pos
299296
300297 split_positions = list (range (start_pos , end_pos + 1 , max_subgraph_size ))
301- deduplicated_splits = list (dict .fromkeys (split_positions ))
302- return deduplicated_splits
298+ if split_positions [- 1 ] != end_pos :
299+ split_positions .append (end_pos )
300+ return sorted (list (set (split_positions )))
303301
304302
305303def generate_initial_tasks (args ):
@@ -321,7 +319,7 @@ def generate_initial_tasks(args):
321319 tasks_map [model_name ] = {
322320 "subgraph_path" : model_path ,
323321 "original_path" : model_path ,
324- "split_positions" : set ( initial_splits ) ,
322+ "split_positions" : initial_splits ,
325323 }
326324
327325 for task in tasks_map .values ():
@@ -448,18 +446,12 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
448446 splits = task_info ["split_positions" ]
449447 if not splits or len (splits ) < 2 :
450448 continue
451- if isinstance (splits , set ):
452- splits = sorted (list (splits ))
453449 start_pos = splits [0 ]
454450 first_segment_end = splits [1 ]
455- new_splits = list (
456- range ( start_pos , first_segment_end + 1 , max_subgraph_size )
451+ new_splits = calculate_split_positions_for_subgraph (
452+ [ start_pos , first_segment_end ] , max_subgraph_size
457453 )
458-
459- if new_splits [- 1 ] != first_segment_end :
460- new_splits .append (first_segment_end )
461-
462- task_info ["split_positions" ] = sorted (list (set (new_splits )))
454+ task_info ["split_positions" ] = new_splits
463455 else :
464456 need_decompose = False
465457 print ()
0 commit comments