Commit 7e95d7f

Merge branch 'opt_saved_results' into add_original_names

2 parents d310856 + 00b070d commit 7e95d7f

File tree

1 file changed: +66 -52 lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 66 additions & 52 deletions
@@ -244,6 +244,12 @@ def run_decomposer_for_multi_models(
         original_path = task_info["original_path"]
 
         split_positions = sorted(list(task_info["split_positions"]))
+
+        method = "fixed-start"
+        if method == "fixed-start":
+            assert len(split_positions) >= 3, f"{split_positions=}"
+            split_positions = [0, split_positions[1]]
+
         rectified_model_path = get_rectfied_model_path(original_path)
         assert os.path.exists(
             rectified_model_path
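For illustration, a minimal sketch of what the new fixed-start branch does to a task's split positions (the values below are made up, and note that `method` is hard-coded here rather than read from the new --method flag):

    # Hypothetical values, for illustration only.
    split_positions = [0, 4, 8, 12]      # sorted split positions of one task
    method = "fixed-start"
    if method == "fixed-start":
        assert len(split_positions) >= 3, f"{split_positions=}"
        # Keep only position 0 and the first interior split position.
        split_positions = [0, split_positions[1]]
    print(split_positions)               # -> [0, 4]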
@@ -293,27 +299,23 @@ def run_evaluation(
     ), f"[ERROR] test failed for {work_dir}, please check the log."
 
 
-def reconstruct_subgraph_size(split_positions: List[int]) -> List[list]:
-    """Reconstructs subgraph size based on sorted split positions."""
-    deduplicated_splits = sorted(set(split_positions))
-
-    subgraph_size = [
-        deduplicated_splits[i : i + 2] for i in range(len(deduplicated_splits) - 1)
-    ]
-    return subgraph_size
+def reconstruct_split_positions_for_subgraphs(
+    split_positions, subgraph_idxs, max_subgraph_size
+):
+    subgraph_idxs = [subgraph_idxs] if isinstance(subgraph_idxs, int) else subgraph_idxs
 
+    new_split_positions = []
+    for subgraph_idx in subgraph_idxs:
+        assert (
+            subgraph_idx < len(split_positions) - 1
+        ), f"subgraph_idx {subgraph_idx} is out of bounds of split_positions: {split_positions}."
 
-def calculate_split_positions_for_subgraph(subgraph_range, max_subgraph_size):
-    assert isinstance(subgraph_range, (list, tuple)) and len(subgraph_range) == 2
+        start_pos, end_pos = split_positions[subgraph_idx : subgraph_idx + 2]
+        new_split_positions = new_split_positions + list(
+            range(start_pos, end_pos + max_subgraph_size - 1, max_subgraph_size)
+        )
 
-    # subgraph_size: the start and end position in original model.
-    start_pos, end_pos = subgraph_range
-    end_pos = kMaxGraphSize if end_pos == float("inf") else end_pos
-
-    split_positions = set(
-        range(start_pos, end_pos + max_subgraph_size - 1, max_subgraph_size)
-    )
-    return list(sorted(split_positions))
+    return sorted(list(set(new_split_positions)))
 
 
 def generate_initial_tasks(args):
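As a rough worked example of the replacement helper (the inputs are arbitrary, and the import assumes graph_net is importable as a package without side effects):

    from graph_net.subgraph_decompose_and_evaluation_step import (
        reconstruct_split_positions_for_subgraphs,
    )

    split_positions = [0, 8, 20]   # boundaries of two subgraphs: [0, 8) and [8, 20)

    # Re-split only subgraph 0: range(0, 8 + 3, 4) -> 0, 4, 8
    print(reconstruct_split_positions_for_subgraphs(split_positions, 0, 4))
    # [0, 4, 8]

    # Re-split subgraphs 0 and 1; per-subgraph positions are merged, deduplicated, sorted:
    # range(8, 20 + 3, 4) -> 8, 12, 16, 20
    print(reconstruct_split_positions_for_subgraphs(split_positions, [0, 1], 4))
    # [0, 4, 8, 12, 16, 20]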
@@ -322,19 +324,16 @@ def generate_initial_tasks(args):
     initial_failures = get_ranged_incorrect_models(args.tolerance, args.log_file)
 
     tasks_map = {}
-    max_subgraph_size = args.max_subgraph_size
+    max_subgraph_size = min(args.max_subgraph_size, kMaxGraphSize // 2)
 
+    initial_split_positions = reconstruct_split_positions_for_subgraphs(
+        [0, kMaxGraphSize], 0, max_subgraph_size
+    )
     for model_path in initial_failures:
         model_name = get_model_name_with_subgraph_tag(model_path)
-
-        initial_range = [0, kMaxGraphSize]
-        initial_splits = calculate_split_positions_for_subgraph(
-            initial_range, max_subgraph_size
-        )
-
         tasks_map[model_name] = {
             "original_path": model_path,
-            "split_positions": initial_splits,
+            "split_positions": initial_split_positions,
         }
 
     running_states = {
@@ -354,7 +353,29 @@ def extract_model_name_and_subgraph_idx(subgraph_path):
     return model_name, subgraph_idx
 
 
-def generate_successor_tasks(base_output_dir, current_pass_id):
+def collect_incorrect_subgraph_idxs(args, model_names, incorrect_models):
+    model_name2subgraph_idxs = {}
+    for subgraph_path in sorted(incorrect_models):
+        model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
+        print(f"{subgraph_path=}")
+
+        if model_name not in model_name2subgraph_idxs:
+            model_name2subgraph_idxs[model_name] = []
+        model_name2subgraph_idxs[model_name].append(subgraph_idx)
+
+    if args.method == "fixed-start":
+        for model_name in model_names:
+            if model_name not in model_name2subgraph_idxs:
+                model_name2subgraph_idxs[model_name] = [1]
+            else:
+                assert (
+                    len(model_name2subgraph_idxs[model_name]) == 1
+                    and model_name2subgraph_idxs[model_name] == [0]
+                )
+    return model_name2subgraph_idxs
+
+
+def generate_successor_tasks(args, base_output_dir, current_pass_id):
     """Generates tasks for Pass > 0 based on previous pass results."""
     prev_pass_dir = get_decompose_workspace_path(base_output_dir, current_pass_id - 1)
     print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")
@@ -367,34 +388,24 @@ def generate_successor_tasks(base_output_dir, current_pass_id):
     tasks_map = {}
     prev_tasks_map = prev_config.tasks_map
 
-    for subgraph_path in sorted(prev_config.incorrect_models):
-        model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
+    model_name2subgraph_idxs = collect_incorrect_subgraph_idxs(
+        args, list(prev_tasks_map.keys()), prev_config.incorrect_models
+    )
 
+    for model_name, subgraph_idxs in model_name2subgraph_idxs.items():
         assert model_name in prev_tasks_map
         pre_task_for_model = prev_tasks_map[model_name]
 
         prev_split_positions = pre_task_for_model.get("split_positions", [])
-        subgraph_ranges = reconstruct_subgraph_size(prev_split_positions)
-
-        assert subgraph_idx < len(
-            subgraph_ranges
-        ), f"subgraph_idx {subgraph_idx} is out of bounds for {model_name} (previous split_positions: {prev_split_positions})"
-
-        split_positions = calculate_split_positions_for_subgraph(
-            subgraph_ranges[subgraph_idx], max_subgraph_size
+        split_positions = reconstruct_split_positions_for_subgraphs(
+            prev_split_positions, subgraph_idxs, max_subgraph_size
         )
-        if model_name not in tasks_map:
-            tasks_map[model_name] = {
-                "original_path": pre_task_for_model["original_path"],
-                "split_positions": list(sorted(split_positions)),
-            }
-        else:
-            merged_split_positions = (
-                tasks_map[model_name]["split_positions"] + split_positions
-            )
-            tasks_map[model_name]["split_positions"] = list(
-                sorted(set(merged_split_positions))
-            )
+
+        tasks_map[model_name] = {
+            "original_path": pre_task_for_model["original_path"],
+            "split_positions": split_positions,
+        }
+    print(f"{tasks_map=}")
 
     return tasks_map, max_subgraph_size, prev_config.running_states
 
@@ -404,7 +415,7 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
         tasks_map, max_subgraph_size, running_states = generate_initial_tasks(args)
     else:
         tasks_map, max_subgraph_size, running_states = generate_successor_tasks(
-            base_output_dir, current_pass_id
+            args, base_output_dir, current_pass_id
         )
 
     print(f"[Init] initial max_subgraph_size: {max_subgraph_size}")
@@ -431,6 +442,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
 
     failed_decomposition = []
     need_decompose = True if len(tasks_map) > 0 else False
+    method = "fixed-start"
 
     while need_decompose:
         decomposed_samples_dir = os.path.join(
@@ -455,6 +467,7 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
             not failed_decomposition
             and num_decomposed_samples == len(tasks_map)
             and max_subgraph_size > 1
+            and method != "fixed-start"
         ):
             need_decompose = True
             shutil.rmtree(decomposed_samples_dir)
@@ -464,8 +477,8 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
                 split_positions = task_info["split_positions"]
                 if not split_positions or len(split_positions) < 2:
                     continue
-                new_split_positions = calculate_split_positions_for_subgraph(
-                    split_positions[0:2], max_subgraph_size
+                new_split_positions = reconstruct_split_positions_for_subgraphs(
+                    split_positions, 0, max_subgraph_size
                 )
                 task_info["split_positions"] = new_split_positions
         else:
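For example, if a task's split positions were [0, 500, 1000] and max_subgraph_size has been reduced to 250 (numbers are illustrative, and the helper from this file is assumed to be in scope), the call above refines only the first subgraph:

    reconstruct_split_positions_for_subgraphs([0, 500, 1000], 0, 250)
    # range(0, 500 + 249, 250) -> [0, 250, 500]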
@@ -579,6 +592,7 @@ def main(args):
     parser.add_argument(
         "--test-config", type=str, required=True, help="Base64 encoded test config"
     )
+    parser.add_argument("--method", type=str, required=True)
     parser.add_argument(
         "--tolerance",
         type=int,
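With the new required flag, an invocation might look like the sketch below; only --test-config, --method, and --tolerance appear in this diff, the config file name is hypothetical, and any other required flags are omitted:

    python graph_net/subgraph_decompose_and_evaluation_step.py \
        --test-config "$(base64 test_config.json)" \
        --tolerance 3 \
        --method fixed-start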
