Skip to content

Commit d349727

Browse files
authored
Refactor split_positions calculation and simplify decomposition logic (#412)
1 parent 5fc8829 commit d349727

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,7 @@ def run_decomposer_for_multi_models(
227227
for model_name, task_info in tasks_map.items():
228228
original_path = task_info["original_path"]
229229

230-
split_positions = task_info["split_positions"]
231-
if isinstance(split_positions, set):
232-
split_positions = sorted(list(split_positions))
233-
230+
split_positions = sorted(list(task_info["split_positions"]))
234231
rectified_model_path = get_rectfied_model_path(original_path)
235232
assert os.path.exists(
236233
rectified_model_path
@@ -298,8 +295,9 @@ def calculate_split_positions_for_subgraph(subgraph_size, max_subgraph_size):
298295
end_pos = kMaxGraphSize if end_pos == float("inf") else end_pos
299296

300297
split_positions = list(range(start_pos, end_pos + 1, max_subgraph_size))
301-
deduplicated_splits = list(dict.fromkeys(split_positions))
302-
return deduplicated_splits
298+
if split_positions[-1] != end_pos:
299+
split_positions.append(end_pos)
300+
return sorted(list(set(split_positions)))
303301

304302

305303
def generate_initial_tasks(args):
@@ -321,7 +319,7 @@ def generate_initial_tasks(args):
321319
tasks_map[model_name] = {
322320
"subgraph_path": model_path,
323321
"original_path": model_path,
324-
"split_positions": set(initial_splits),
322+
"split_positions": initial_splits,
325323
}
326324

327325
for task in tasks_map.values():
@@ -448,18 +446,12 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
448446
splits = task_info["split_positions"]
449447
if not splits or len(splits) < 2:
450448
continue
451-
if isinstance(splits, set):
452-
splits = sorted(list(splits))
453449
start_pos = splits[0]
454450
first_segment_end = splits[1]
455-
new_splits = list(
456-
range(start_pos, first_segment_end + 1, max_subgraph_size)
451+
new_splits = calculate_split_positions_for_subgraph(
452+
[start_pos, first_segment_end], max_subgraph_size
457453
)
458-
459-
if new_splits[-1] != first_segment_end:
460-
new_splits.append(first_segment_end)
461-
462-
task_info["split_positions"] = sorted(list(set(new_splits)))
454+
task_info["split_positions"] = new_splits
463455
else:
464456
need_decompose = False
465457
print()

0 commit comments

Comments
 (0)