Skip to content

Commit a2f80c6

Browse files
committed
Merge branch 'opt_saved_results' into add_original_names
2 parents 7e95d7f + 7cfd4eb commit a2f80c6

File tree

1 file changed

+27
-18
lines changed

1 file changed

+27
-18
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,9 @@ def _print(self):
111111

112112
@dataclass
113113
class DecomposeConfig:
114+
method: str
115+
tolerance: int | List[int]
114116
max_subgraph_size: int = -1
115-
incorrect_models: List[str] = field(default_factory=list)
116117
tasks_map: Dict[str, Union[int, str, list, dict]] = field(default_factory=dict)
117118
running_states: Dict[str, Union[int, str, list, dict]] = field(default_factory=dict)
118119

@@ -139,6 +140,11 @@ def load(self, work_dir):
139140
def get_config_path(self, work_dir) -> str:
140141
return os.path.join(work_dir, "decompose_config.json")
141142

143+
def get_incorrect_models(self, pass_id):
144+
pass_key = get_pass_name(pass_id)
145+
assert pass_key in self.running_states
146+
return self.running_states[pass_key]["incorrect_models"]
147+
142148
def update_running_states(self, pass_id, **kwargs):
143149
pass_key = get_pass_name(pass_id)
144150
if self.running_states.get(pass_key, None) is None:
@@ -242,7 +248,6 @@ def run_decomposer_for_multi_models(
242248
)
243249
for model_name, task_info in tasks_map.items():
244250
original_path = task_info["original_path"]
245-
246251
split_positions = sorted(list(task_info["split_positions"]))
247252

248253
method = "fixed-start"
@@ -312,9 +317,8 @@ def reconstruct_split_positions_for_subgraphs(
312317

313318
start_pos, end_pos = split_positions[subgraph_idx : subgraph_idx + 2]
314319
new_split_positions = new_split_positions + list(
315-
range(start_pos, end_pos + max_subgraph_size - 1, max_subgraph_size)
320+
range(start_pos, end_pos + max_subgraph_size, max_subgraph_size)
316321
)
317-
318322
return sorted(list(set(new_split_positions)))
319323

320324

@@ -353,25 +357,27 @@ def extract_model_name_and_subgraph_idx(subgraph_path):
353357
return model_name, subgraph_idx
354358

355359

356-
def collect_incorrect_subgraph_idxs(args, model_names, incorrect_models):
360+
def collect_incorrect_subgraph_idxs(args, target_model_names, incorrect_models):
357361
model_name2subgraph_idxs = {}
358362
for subgraph_path in sorted(incorrect_models):
359363
model_name, subgraph_idx = extract_model_name_and_subgraph_idx(subgraph_path)
360364
print(f"{subgraph_path=}")
365+
print(f"{model_name=}, {subgraph_idx=}")
366+
assert model_name in target_model_names, f"{model_name=}, {subgraph_idx=}"
361367

362368
if model_name not in model_name2subgraph_idxs:
363369
model_name2subgraph_idxs[model_name] = []
364370
model_name2subgraph_idxs[model_name].append(subgraph_idx)
365371

366372
if args.method == "fixed-start":
367-
for model_name in model_names:
373+
print(model_name2subgraph_idxs)
374+
for model_name in target_model_names:
368375
if model_name not in model_name2subgraph_idxs:
369376
model_name2subgraph_idxs[model_name] = [1]
370377
else:
371-
assert (
372-
len(model_name2subgraph_idxs[model_name]) == 1
373-
and model_name2subgraph_idxs[model_name] == 0
374-
)
378+
assert len(
379+
model_name2subgraph_idxs[model_name]
380+
) == 1 and model_name2subgraph_idxs[model_name] == [0]
375381
return model_name2subgraph_idxs
376382

377383

@@ -382,18 +388,19 @@ def generate_successor_tasks(args, base_output_dir, current_pass_id):
382388

383389
prev_config = DecomposeConfig.load(prev_pass_dir)
384390
max_subgraph_size = prev_config.max_subgraph_size // 2
385-
if not prev_config.incorrect_models:
391+
incorrect_models = prev_config.get_incorrect_models(current_pass_id)
392+
if args.method != "fixed-start" and not incorrect_models:
386393
return {}, max_subgraph_size, prev_config.running_states
387394

388395
tasks_map = {}
389396
prev_tasks_map = prev_config.tasks_map
390397

398+
target_model_names = list(prev_tasks_map.keys())
391399
model_name2subgraph_idxs = collect_incorrect_subgraph_idxs(
392-
args, list(prev_tasks_map.keys()), prev_config.incorrect_models
400+
args, target_model_names, incorrect_models
393401
)
394402

395403
for model_name, subgraph_idxs in model_name2subgraph_idxs.items():
396-
assert model_name in prev_tasks_map
397404
pre_task_for_model = prev_tasks_map[model_name]
398405

399406
prev_split_positions = pre_task_for_model.get("split_positions", [])
@@ -500,8 +507,7 @@ def count_unique_original_models(incorrect_models):
500507
return len(original_model_paths)
501508

502509

503-
def print_summary_and_suggestion(next_round_models, max_subgraph_size):
504-
"""Print suggestion/result."""
510+
def print_summary_and_suggestion(args, next_round_models, max_subgraph_size):
505511
print("\n" + "=" * 80)
506512
if next_round_models and max_subgraph_size > 1:
507513
print(f">>> [SUGGESTION] Issues remain (Count: {len(next_round_models)}).")
@@ -527,6 +533,8 @@ def main(args):
527533
args, current_pass_id, base_output_dir
528534
)
529535
decompose_config = DecomposeConfig(
536+
method=args.method,
537+
tolerance=args.tolerance,
530538
max_subgraph_size=max_subgraph_size,
531539
tasks_map=tasks_map,
532540
running_states=running_states,
@@ -559,7 +567,6 @@ def main(args):
559567
run_evaluation(args.framework, args.test_config, work_dir, log_path)
560568

561569
# --- Step 4: Analysis ---
562-
next_pass_incorrect_models = set()
563570
if task_controller.task_scheduler["post_analysis"]:
564571
tolerance = (
565572
args.tolerance[0] if isinstance(args.tolerance, list) else args.tolerance
@@ -572,15 +579,17 @@ def main(args):
572579
num_incorrect_models=num_original_models,
573580
incorrect_models=list(next_pass_incorrect_models),
574581
)
582+
575583
print(
576584
f"[Analysis] Found {len(next_pass_incorrect_models)} incorrect subgraphs ({num_original_models} original models)."
577585
)
578586
for idx, model_path in enumerate(next_pass_incorrect_models):
579587
print(f"- [{idx}] {model_path}")
580-
print_summary_and_suggestion(next_pass_incorrect_models, max_subgraph_size)
588+
print_summary_and_suggestion(
589+
args, next_pass_incorrect_models, max_subgraph_size
590+
)
581591

582592
# --- Step 5: Save States ---
583-
decompose_config.incorrect_models = list(next_pass_incorrect_models)
584593
decompose_config.save(work_dir)
585594

586595

0 commit comments

Comments
 (0)