
Commit 1196549

Optimize code.
1 parent 74f423e commit 1196549

File tree: 1 file changed (+58, -52)

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 58 additions & 52 deletions
@@ -139,6 +139,19 @@ def load(self, work_dir):
     def get_config_path(self, work_dir) -> str:
         return os.path.join(work_dir, "decompose_config.json")
 
+    def update_running_states(self, pass_id, **kwargs):
+        pass_key = get_pass_name(pass_id)
+        if self.running_states.get(pass_key, None) is None:
+            self.running_states[pass_key] = {}
+
+        for key, value in kwargs.items():
+            assert key in [
+                "num_incorrect_models",
+                "incorrect_models",
+                "failed_decomposition_models",
+            ]
+            self.running_states[pass_key][key] = value
+
 
 def get_rectfied_model_path(model_path):
     graphnet_root = path_utils.get_graphnet_root()
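Editorial note on the new update_running_states method: it keys the shared running_states dict by pass name and rejects any field outside a small whitelist. A minimal, standalone sketch of how it behaves, assuming get_pass_name(pass_id) simply formats the key as pass_<id> (that helper is not shown in this diff) and using a stripped-down stand-in for DecomposeConfig with only the fields this method touches:

# Sketch only: get_pass_name is a hypothetical stand-in, consistent with the
# "pass_{id}" keys used elsewhere in this file.
def get_pass_name(pass_id):
    return f"pass_{pass_id}"

class DecomposeConfig:
    def __init__(self):
        self.running_states = {}

    def update_running_states(self, pass_id, **kwargs):
        pass_key = get_pass_name(pass_id)
        if self.running_states.get(pass_key, None) is None:
            self.running_states[pass_key] = {}
        for key, value in kwargs.items():
            assert key in [
                "num_incorrect_models",
                "incorrect_models",
                "failed_decomposition_models",
            ]
            self.running_states[pass_key][key] = value

config = DecomposeConfig()
config.update_running_states(1, num_incorrect_models=2, incorrect_models=["m0/subgraph_3", "m1/subgraph_0"])
print(config.running_states)
# {'pass_1': {'num_incorrect_models': 2, 'incorrect_models': ['m0/subgraph_3', 'm1/subgraph_0']}}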
@@ -268,11 +281,10 @@ def run_evaluation(
 
 def reconstruct_subgraph_size(split_positions: List[int]) -> List[list]:
     """Reconstructs subgraph size based on sorted split positions."""
-    deduplicated_splits = list(dict.fromkeys(split_positions))
+    deduplicated_splits = sorted(set(split_positions))
 
     subgraph_size = [
-        [deduplicated_splits[i], deduplicated_splits[i + 1]]
-        for i in range(len(deduplicated_splits) - 1)
+        deduplicated_splits[i : i + 2] for i in range(len(deduplicated_splits) - 1)
     ]
     return subgraph_size
 
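The deduplication change is behavioral, not just stylistic: list(dict.fromkeys(...)) keeps first-seen order, while sorted(set(...)) sorts the positions, which is what the docstring promises and what the pairwise slicing relies on. A quick illustration with made-up split positions:

split_positions = [30, 0, 30, 10]

# Old: de-duplicate, preserve insertion order.
list(dict.fromkeys(split_positions))        # -> [30, 0, 10]

# New: de-duplicate and sort, so adjacent pairs form valid [start, end] ranges.
deduplicated_splits = sorted(set(split_positions))   # -> [0, 10, 30]
[deduplicated_splits[i : i + 2] for i in range(len(deduplicated_splits) - 1)]
# -> [[0, 10], [10, 30]]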

@@ -328,7 +340,7 @@ def extract_model_name_and_subgraph_idx(subgraph_path):
     return model_name, subgraph_idx
 
 
-def generate_refined_tasks(base_output_dir, current_pass_id):
+def generate_successor_tasks(base_output_dir, current_pass_id):
     """Generates tasks for Pass > 0 based on previous pass results."""
     prev_pass_dir = get_decompose_workspace_path(base_output_dir, current_pass_id - 1)
     print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")
@@ -377,7 +389,7 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
     if current_pass_id == 0:
         tasks_map, max_subgraph_size, running_states = generate_initial_tasks(args)
     else:
-        tasks_map, max_subgraph_size, running_states = generate_refined_tasks(
+        tasks_map, max_subgraph_size, running_states = generate_successor_tasks(
             base_output_dir, current_pass_id
         )
 
@@ -435,15 +447,13 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
         os.makedirs(decomposed_samples_dir, exist_ok=True)
         max_subgraph_size = max(1, max_subgraph_size // 2)
         for model_name, task_info in tasks_map.items():
-            splits = task_info["split_positions"]
-            if not splits or len(splits) < 2:
+            split_positions = task_info["split_positions"]
+            if not split_positions or len(split_positions) < 2:
                 continue
-            start_pos = splits[0]
-            first_segment_end = splits[1]
-            new_splits = calculate_split_positions_for_subgraph(
-                [start_pos, first_segment_end], max_subgraph_size
+            new_split_positions = calculate_split_positions_for_subgraph(
+                split_positions[0:2], max_subgraph_size
             )
-            task_info["split_positions"] = new_splits
+            task_info["split_positions"] = new_split_positions
     else:
         need_decompose = False
     print()
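The simplification above relies on split_positions[0:2] being exactly the [start_pos, first_segment_end] pair that the old code built by hand. A small sketch of the call; calculate_split_positions_for_subgraph is not shown in this diff, so the version below is a hypothetical stand-in that splits a [start, end] range into chunks of at most max_subgraph_size:

# Hypothetical stand-in: split [start, end] into steps of max_subgraph_size.
def calculate_split_positions_for_subgraph(segment, max_subgraph_size):
    start, end = segment
    positions = list(range(start, end, max_subgraph_size))
    positions.append(end)
    return positions

split_positions = [0, 16, 32]
assert split_positions[0:2] == [split_positions[0], split_positions[1]]  # [0, 16]
print(calculate_split_positions_for_subgraph(split_positions[0:2], 4))
# [0, 4, 8, 12, 16]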
@@ -454,6 +464,15 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
     return tasks_map, failed_decomposition, max_subgraph_size
 
 
+def count_unique_original_models(incorrect_models):
+    original_model_paths = set(
+        model_name
+        for subgraph_path in incorrect_models
+        for model_name, _ in [extract_model_name_and_subgraph_idx(subgraph_path)]
+    )
+    return len(original_model_paths)
+
+
 def print_summary_and_suggestion(next_round_models, max_subgraph_size):
     """Print suggestion/result."""
     print("\n" + "=" * 80)
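The new count_unique_original_models helper lifts the counting logic out of main. Its "for model_name, _ in [extract_model_name_and_subgraph_idx(...)]" clause is just a trick for unpacking a tuple inside a comprehension; an equivalent plain-loop version, assuming extract_model_name_and_subgraph_idx returns a (model_name, subgraph_idx) pair as shown earlier in the diff, reads:

def count_unique_original_models(incorrect_models):
    # Equivalent rewrite without the single-element-list unpacking trick.
    original_model_paths = set()
    for subgraph_path in incorrect_models:
        model_name, _ = extract_model_name_and_subgraph_idx(subgraph_path)
        original_model_paths.add(model_name)
    return len(original_model_paths)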
@@ -480,9 +499,14 @@ def main(args):
     tasks_map, max_subgraph_size, running_states = prepare_tasks_and_verify(
         args, current_pass_id, base_output_dir
     )
-    pass_work_dir = get_decompose_workspace_path(base_output_dir, current_pass_id)
-    if not os.path.exists(pass_work_dir):
-        os.makedirs(pass_work_dir, exist_ok=True)
+    decompose_config = DecomposeConfig(
+        max_subgraph_size=max_subgraph_size,
+        tasks_map=tasks_map,
+        running_states=running_states,
+    )
+    work_dir = get_decompose_workspace_path(base_output_dir, current_pass_id)
+    if not os.path.exists(work_dir):
+        os.makedirs(work_dir, exist_ok=True)
 
     # --- Step 2: Decomposition ---
     if task_controller.task_scheduler["run_decomposer"]:
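One thing the new code carries over from the old: the os.path.exists check before os.makedirs(work_dir, exist_ok=True) is redundant, since exist_ok=True already makes the call a no-op for an existing directory. A minimal demonstration:

import os
import tempfile

with tempfile.TemporaryDirectory() as base:
    work_dir = os.path.join(base, "pass_0")
    os.makedirs(work_dir, exist_ok=True)  # creates the directory
    os.makedirs(work_dir, exist_ok=True)  # silently succeeds on the second call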
@@ -492,63 +516,45 @@ def main(args):
             failed_decomposition,
             max_subgraph_size,
         ) = execute_decomposition_phase(
-            max_subgraph_size, tasks_map, args.framework, pass_work_dir
+            max_subgraph_size, tasks_map, args.framework, work_dir
+        )
+        decompose_config.update_running_states(
+            current_pass_id, failed_decomposition_models=list(failed_decomposition)
         )
-        running_states.get(f"pass_{current_pass_id}", {})[
-            "failed_decomposition_models"
-        ] = list(failed_decomposition)
     else:
         print("\n--- Phase 1: Decomposition (skipped) ---", flush=True)
-        config = DecomposeConfig.load(pass_work_dir)
-        max_subgraph_size = config.max_subgraph_size
-        tasks_map = config.tasks_map
-        running_states = config.running_states
+        decompose_config = DecomposeConfig.load(work_dir)
 
     # --- Step 3: Evaluation ---
-    pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
+    log_path = os.path.join(work_dir, f"log_{task_controller.test_module_name}.txt")
     if task_controller.task_scheduler["run_evaluation"]:
         print(f"\n--- Phase 2: Evaluation ({task_controller.test_module_name}) ---")
-        run_evaluation(args.framework, args.test_config, pass_work_dir, pass_log_path)
+        run_evaluation(args.framework, args.test_config, work_dir, log_path)
 
     # --- Step 4: Analysis ---
-    next_round_models = set()
+    next_pass_incorrect_models = set()
     if task_controller.task_scheduler["post_analysis"]:
         tolerance = (
            args.tolerance[0] if isinstance(args.tolerance, list) else args.tolerance
        )
         print(f"\n--- Phase 3: Analysis (torlance={tolerance}) ---")
-        next_round_models = sorted(get_incorrect_models(tolerance, pass_log_path))
-        original_model_paths = set(
-            [
-                model_name
-                for subgraph_path in next_round_models
-                for model_name, _ in [
-                    extract_model_name_and_subgraph_idx(subgraph_path)
-                ]
-            ]
+        next_pass_incorrect_models = sorted(get_incorrect_models(tolerance, log_path))
+        num_original_models = count_unique_original_models(next_pass_incorrect_models)
+        decompose_config.update_running_states(
+            current_pass_id + 1,
+            num_incorrect_models=num_original_models,
+            incorrect_models=list(next_pass_incorrect_models),
         )
-
-        running_states[f"pass_{current_pass_id + 1}"] = {
-            "num_incorrect_models": len(original_model_paths),
-            "incorrect_models": list(next_round_models),
-        }
-
         print(
-            f"[Analysis] Found {len(next_round_models)} incorrect subgraphs ({len(original_model_paths)} original models)."
+            f"[Analysis] Found {len(next_pass_incorrect_models)} incorrect subgraphs ({num_original_models} original models)."
        )
-        for idx, model_path in enumerate(next_round_models):
+        for idx, model_path in enumerate(next_pass_incorrect_models):
             print(f"- [{idx}] {model_path}")
-
-        print_summary_and_suggestion(next_round_models, max_subgraph_size)
+        print_summary_and_suggestion(next_pass_incorrect_models, max_subgraph_size)
 
     # --- Step 5: Save States ---
-    config = DecomposeConfig(
-        max_subgraph_size=max_subgraph_size,
-        incorrect_models=list(next_round_models),
-        tasks_map=tasks_map,
-        running_states=running_states,
-    )
-    config.save(pass_work_dir)
+    decompose_config.incorrect_models = list(next_pass_incorrect_models)
+    decompose_config.save(work_dir)
 
 
 if __name__ == "__main__":