Skip to content

Commit d7c91a2

Browse files
committed
Opimize codes.
1 parent c067624 commit d7c91a2

File tree

1 file changed

+20
-35
lines changed

1 file changed

+20
-35
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 20 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -279,27 +279,18 @@ def run_evaluation(
279279
), f"[ERROR] test failed for {samples_dir}, please check the log."
280280

281281

282-
def reconstruct_subgraph_size(split_positions: List[int]) -> List[list]:
283-
"""Reconstructs subgraph size based on sorted split positions."""
284-
deduplicated_splits = sorted(set(split_positions))
285-
286-
subgraph_size = [
287-
deduplicated_splits[i : i + 2] for i in range(len(deduplicated_splits) - 1)
288-
]
289-
return subgraph_size
290-
291-
292-
def calculate_split_positions_for_subgraph(subgraph_range, max_subgraph_size):
293-
assert isinstance(subgraph_range, (list, tuple)) and len(subgraph_range) == 2
294-
295-
# subgraph_size: the start and end position in original model.
296-
start_pos, end_pos = subgraph_range
297-
end_pos = kMaxGraphSize if end_pos == float("inf") else end_pos
282+
def reconstruct_split_positions_for_subgraph(
283+
split_positions, subgraph_idx, max_subgraph_size
284+
):
285+
assert (
286+
subgraph_idx < len(split_positions) - 1
287+
), f"subgraph_idx {subgraph_idx} is out of bounds of split_positions: {split_positions}."
298288

299-
split_positions = set(
289+
start_pos, end_pos = split_positions[subgraph_idx : subgraph_idx + 2]
290+
new_split_positions = set(
300291
range(start_pos, end_pos + max_subgraph_size - 1, max_subgraph_size)
301292
)
302-
return sorted(list(set(split_positions)))
293+
return sorted(list(new_split_positions))
303294

304295

305296
def generate_initial_tasks(args):
@@ -310,17 +301,14 @@ def generate_initial_tasks(args):
310301
tasks_map = {}
311302
max_subgraph_size = args.max_subgraph_size
312303

304+
initial_split_positions = reconstruct_split_positions_for_subgraph(
305+
[0, kMaxGraphSize], 0, max_subgraph_size
306+
)
313307
for model_path in initial_failures:
314308
model_name = get_model_name_with_subgraph_tag(model_path)
315-
316-
initial_range = [0, kMaxGraphSize]
317-
initial_splits = calculate_split_positions_for_subgraph(
318-
initial_range, max_subgraph_size
319-
)
320-
321309
tasks_map[model_name] = {
322310
"original_path": model_path,
323-
"split_positions": initial_splits,
311+
"split_positions": initial_split_positions,
324312
}
325313

326314
running_states = {
@@ -355,19 +343,14 @@ def generate_successor_tasks(base_output_dir, current_pass_id):
355343

356344
for subgraph_path in sorted(prev_config.incorrect_models):
357345
model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
346+
print(f"{subgraph_path=}")
358347

359348
assert model_name in prev_tasks_map
360349
pre_task_for_model = prev_tasks_map[model_name]
361350

362351
prev_split_positions = pre_task_for_model.get("split_positions", [])
363-
subgraph_ranges = reconstruct_subgraph_size(prev_split_positions)
364-
365-
assert subgraph_idx < len(
366-
subgraph_ranges
367-
), f"subgraph_idx {subgraph_idx} is out of bounds for {model_name} (previous split_positions: {prev_split_positions})"
368-
369-
split_positions = calculate_split_positions_for_subgraph(
370-
subgraph_ranges[subgraph_idx], max_subgraph_size
352+
split_positions = reconstruct_split_positions_for_subgraph(
353+
prev_split_positions, subgraph_idx, max_subgraph_size
371354
)
372355
if model_name not in tasks_map:
373356
tasks_map[model_name] = {
@@ -381,6 +364,7 @@ def generate_successor_tasks(base_output_dir, current_pass_id):
381364
tasks_map[model_name]["split_positions"] = list(
382365
sorted(set(merged_split_positions))
383366
)
367+
print(f"{tasks_map=}")
384368

385369
return tasks_map, max_subgraph_size, prev_config.running_states
386370

@@ -409,6 +393,7 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
409393
)
410394
sys.exit(0)
411395

396+
sys.exit(0)
412397
return tasks_map, max_subgraph_size, running_states
413398

414399

@@ -450,8 +435,8 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
450435
split_positions = task_info["split_positions"]
451436
if not split_positions or len(split_positions) < 2:
452437
continue
453-
new_split_positions = calculate_split_positions_for_subgraph(
454-
split_positions[0:2], max_subgraph_size
438+
new_split_positions = reconstruct_split_positions_for_subgraph(
439+
split_positions, 0, max_subgraph_size
455440
)
456441
task_info["split_positions"] = new_split_positions
457442
else:

0 commit comments

Comments
 (0)