
Commit c75573e

Optimize the main function
1 parent ad4b5da commit c75573e


graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 149 additions & 126 deletions
@@ -296,97 +296,168 @@ def calculate_split_positions_for_subgraph(subgraph_size):
     return split_positions
 
 
-def main(args):
-    task_controller = TaskController(args)
-    base_output_dir = task_controller.root_output_dir
-    current_pass_id = task_controller.current_pass_id
+def generate_initial_tasks(log_file, tolerance, max_subgraph_size):
+    """Generates tasks for Pass 0 based on the initial log file."""
+    print(f"[Init] Pass 0: Reading from log file: {log_file}")
+    initial_failures = get_incorrect_models(tolerance, log_file)
 
-    print("=" * 80)
-    print(f" GraphNet Auto-Debugger | ROUND: PASS_{current_pass_id}")
-    print("=" * 80)
+    # Dynamic generation based on step size
+    initial_splits = list(range(0, kMaxGraphSize + 1, max_subgraph_size))
 
     tasks_map = {}
     active_models_map_for_save = {}
 
-    # Initialize using the argument passed from bash
-    max_subgraph_size = args.max_subgraph_size
+    for path in initial_failures:
+        name = os.path.basename(path.rstrip("/"))
+        active_models_map_for_save[name] = path
+        tasks_map[name] = {
+            "original_path": path,
+            "split_positions": set(initial_splits),
+        }
+    return tasks_map, active_models_map_for_save, max_subgraph_size
 
-    if current_pass_id == 0:
-        print(f"[Init] Pass 0: Reading from log file: {args.log_file}")
-        initial_failures = get_incorrect_models(args.tolerance, args.log_file)
-
-        # Dynamic generation based on step size (args.max_subgraph_size)
-        initial_splits = list(range(0, kMaxGraphSize + 1, max_subgraph_size))
-
-        for path in initial_failures:
-            name = os.path.basename(path.rstrip("/"))
-            active_models_map_for_save[name] = path
-            tasks_map[name] = {
-                "original_path": path,
-                "split_positions": set(initial_splits),
-            }
-    else:
-        prev_pass_dir = get_decompose_workspace_path(
-            base_output_dir, current_pass_id - 1
+
+def generate_refined_tasks(base_output_dir, current_pass_id, default_max_size):
+    """Generates tasks for Pass > 0 based on previous pass results."""
+    prev_pass_dir = get_decompose_workspace_path(base_output_dir, current_pass_id - 1)
+    print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")
+
+    prev_config = load_decompose_config(prev_pass_dir)
+    prev_active_models_map = prev_config.get("active_models_map", {})
+    prev_used_splits = prev_config.get("split_positions_map", {})
+    prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
+
+    # Load previous max size as fallback
+    max_subgraph_size = prev_config.get("max_subgraph_size", default_max_size)
+
+    if not prev_incorrect_subgraphs:
+        return {}, {}, max_subgraph_size
+
+    print("[Analysis] Refining splits based on previous incorrect models ...")
+
+    tasks_map = {}
+    active_models_map_for_save = {}
+
+    for subgraph_path in prev_incorrect_subgraphs:
+        # Parse model name and subgraph index
+        model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
+        model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
+        subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
+
+        if model_name not in prev_active_models_map:
+            continue
+
+        active_models_map_for_save[model_name] = prev_active_models_map[model_name]
+
+        # Reconstruct previous subgraph size to locate the failing segment
+        prev_split_positions = sorted(prev_used_splits.get(model_name, []))
+        subgraph_size = reconstruct_subgraph_size(prev_split_positions)
+
+        if subgraph_idx >= len(subgraph_size):
+            print(
+                f"[WARN] Subgraph index {subgraph_idx} out of bounds for {model_name}"
+            )
+            continue
+
+        split_positions = calculate_split_positions_for_subgraph(
+            subgraph_size[subgraph_idx]
         )
-        print(
-            f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})..."
+
+        if model_name not in tasks_map:
+            tasks_map[model_name] = {
+                "subgraph_path": subgraph_path,
+                "original_path": prev_active_models_map[model_name],
+                "subgraph_size": subgraph_size[subgraph_idx],
+                "split_positions": split_positions,
+            }
+
+    return tasks_map, active_models_map_for_save, max_subgraph_size
+
+
+def execute_decomposition_phase(tasks_map, framework, pass_work_dir, should_run):
+    """Executes the decomposition phase (Phase 1)."""
+    failed_decomposition = []
+    final_used_splits_map = {}
+
+    if not should_run or not tasks_map:
+        return failed_decomposition, final_used_splits_map
+
+    print("\n--- Phase 1: Decomposition ---", flush=True)
+
+    decomposed_samples_dir = os.path.join(
+        pass_work_dir, "samples" if framework == "torch" else "paddle_samples"
+    )
+    os.makedirs(decomposed_samples_dir, exist_ok=True)
+    print(f"decomposed_samples_dir: {decomposed_samples_dir}")
+
+    for model_name, task_info in tasks_map.items():
+        original_path = task_info["original_path"]
+        split_positions = sorted(list(task_info["split_positions"]))
+        final_used_splits_map[model_name] = split_positions
+
+        rectified_model_path = get_rectfied_model_path(original_path)
+
+        success = run_decomposer(
+            framework, rectified_model_path, decomposed_samples_dir, split_positions
        )
+        if not success:
+            failed_decomposition.append(rectified_model_path)
 
-        prev_config = load_decompose_config(prev_pass_dir)
-        prev_active_models_map = prev_config.get("active_models_map", {})
-        prev_used_splits = prev_config.get("split_positions_map", {})
-        prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
+    num_samples = count_samples(decomposed_samples_dir)
+    print(f"- number of graphs: {len(tasks_map)} -> {num_samples}", flush=True)
 
-        # Load previous max size as fallback for calculation
-        prev_max_size = prev_config.get("max_subgraph_size", args.max_subgraph_size)
-        max_subgraph_size = prev_max_size
+    if failed_decomposition:
+        print(f"[WARN] {len(failed_decomposition)} models failed to decompose.")
 
-        if not prev_incorrect_subgraphs:
-            print("[FINISHED] Debugging completed.")
-            sys.exit(0)
+    return failed_decomposition, final_used_splits_map
 
-        print("[Analysis] Refining splits based on previous incorrect models ...")
 
-        for subgraph_path in prev_incorrect_subgraphs:
-            print(f"- subgraph_path: {subgraph_path}")
-            model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
-            model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
-            subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
-            print(f"- model_name: {model_name}, subgraph_idx: {subgraph_idx}")
+def print_final_summary(next_round_models, real_subgraph_size):
+    """Prints the final suggestion/result."""
+    print("\n" + "=" * 80)
+    if next_round_models and real_subgraph_size > 1:
+        print(f">>> [SUGGESTION] Issues remain (Count: {len(next_round_models)}).")
+        print(">>> Please start next round decomposition test (Run this script again).")
+    elif next_round_models and real_subgraph_size <= 1:
+        print(">>> [FAILURE] Minimal granularity reached, but errors persist.")
+    else:
+        print(">>> [SUCCESS] Debugging converged.")
+    print("=" * 80)
 
-            assert model_name in prev_active_models_map
-            active_models_map_for_save[model_name] = prev_active_models_map[model_name]
 
-            # Reconstruct previous subgraph size to locate the failing segment
-            prev_split_positions = sorted(prev_used_splits.get(model_name, []))
-            subgraph_size = reconstruct_subgraph_size(prev_split_positions)
-            assert subgraph_idx < len(
-                subgraph_size
-            ), f"subgraph_idx {subgraph_idx} is out of bounds for {model_name} (previous split_positions: {prev_split_positions})"
+def main(args):
+    task_controller = TaskController(args)
+    base_output_dir = task_controller.root_output_dir
+    current_pass_id = task_controller.current_pass_id
+
+    print("=" * 80)
+    print(f" GraphNet Auto-Debugger | ROUND: PASS_{current_pass_id}")
+    print("=" * 80)
 
-            split_positions = calculate_split_positions_for_subgraph(
-                subgraph_size[subgraph_idx]
-            )
-            if model_name not in tasks_map:
-                tasks_map[model_name] = {
-                    "subgraph_path": subgraph_path,
-                    "original_path": prev_active_models_map[model_name],
-                    "subgraph_size": subgraph_size[subgraph_idx],
-                    "split_positions": split_positions,
-                }
-            else:
-                continue
-
-    # Recalculate based on current map to ensure log accuracy
+    # --- Step 1: Prepare Tasks ---
+    if current_pass_id == 0:
+        (
+            tasks_map,
+            active_models_map_for_save,
+            max_subgraph_size,
+        ) = generate_initial_tasks(
+            args.log_file, args.tolerance, args.max_subgraph_size
+        )
+    else:
+        (
+            tasks_map,
+            active_models_map_for_save,
+            max_subgraph_size,
+        ) = generate_refined_tasks(
+            base_output_dir, current_pass_id, args.max_subgraph_size
+        )
+
+    # Recalculate size for logging
     real_subgraph_size = calculate_current_subgraph_size(tasks_map, max_subgraph_size)
     print(f"[INFO] Current Subgraph Size: {real_subgraph_size}")
     print(f"[INFO] Models to Process: {len(tasks_map)}")
-    for model_name, task_info in tasks_map.items():
-        original_path = task_info["original_path"]
-        print(f"- {original_path}")
 
-    if not tasks_map:
+    if not tasks_map and current_pass_id > 0:
        print("[FINISHED] No models need processing.")
        sys.exit(0)
 
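Note: the two task generators above share the split-position bookkeeping that drives the whole bisection loop. Pass 0 seeds every failing model with uniform splits, and later passes recover the failing segment from its `<model_name>_<subgraph_idx>` directory name. A minimal standalone sketch of that convention (the kMaxGraphSize value and helper names below are illustrative stand-ins, not this module's real definitions):

    kMaxGraphSize = 100  # hypothetical value, for illustration only

    def initial_split_positions(max_subgraph_size):
        # Pass 0: uniform split positions, as in generate_initial_tasks above.
        return list(range(0, kMaxGraphSize + 1, max_subgraph_size))

    def parse_failing_subgraph_dirname(dirname):
        # Pass > 0: split on the last underscore, as in generate_refined_tasks,
        # so model names may themselves contain underscores.
        parts = dirname.rstrip("/").split("_")
        return "_".join(parts[:-1]), int(parts[-1])

    print(initial_split_positions(25))  # [0, 25, 50, 75, 100]
    print(parse_failing_subgraph_dirname("stable_diffusion_3"))  # ('stable_diffusion', 3)

Note also that generate_refined_tasks replaces the old asserts with skip-and-warn behavior (continue on unknown models or out-of-range indices), so one malformed entry in the previous pass's config no longer aborts the whole round.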

@@ -396,57 +467,17 @@ def main(args):
     os.makedirs(pass_work_dir, exist_ok=True)
 
     # --- Step 3: Decomposition ---
-    need_decompose = (
-        True
-        if task_controller.task_scheduler["run_decomposer"] and len(tasks_map) > 0
-        else False
+    failed_decomposition, final_used_splits_map = execute_decomposition_phase(
+        tasks_map,
+        args.framework,
+        pass_work_dir,
+        task_controller.task_scheduler["run_decomposer"],
     )
-    if need_decompose:
-        print("\n--- Phase 1: Decomposition ---", flush=True)
-
-    failed_decomposition = []
-    final_used_splits_map = {}
-    if need_decompose:
-        decomposed_samples_dir = os.path.join(
-            pass_work_dir, "samples" if args.framework == "torch" else "paddle_samples"
-        )
-        os.makedirs(decomposed_samples_dir, exist_ok=True)
-        print(f"decomposed_samples_dir: {decomposed_samples_dir}")
-
-        for model_name, task_info in tasks_map.items():
-            original_path = task_info["original_path"]
-            split_positions = sorted(list(task_info["split_positions"]))
-
-            final_used_splits_map[model_name] = split_positions
-
-            rectied_model_path = get_rectfied_model_path(original_path)
-            print(f"original_path: {original_path}")
-            print(f"rectied_model_path: {rectied_model_path}")
-            assert os.path.exists(
-                rectied_model_path
-            ), f"{rectied_model_path} does not exist."
-
-            success = run_decomposer(
-                args.framework,
-                rectied_model_path,
-                decomposed_samples_dir,
-                split_positions,
-            )
-            if not success:
-                failed_decomposition.append(rectied_model_path)
-
-        num_decomposed_samples = count_samples(decomposed_samples_dir)
-        print(
-            f"- number of graphs: {len(tasks_map)} -> {num_decomposed_samples}",
-            flush=True,
-        )
-        if failed_decomposition:
-            print(f"[WARN] {len(failed_decomposition)} models failed to decompose.")
 
     # --- Step 4: Testing ---
+    pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
     if task_controller.task_scheduler["run_evaluation"]:
         print("\n--- Phase 2: Batch Testing ---")
-        pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
         run_evaluation(args.framework, args.test_config, pass_work_dir, pass_log_path)
 
     # --- Step 5: Analysis ---
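Two behavioral details in the refactored Steps 3 and 4: execute_decomposition_phase takes the scheduler flag as should_run and degrades to a no-op that still returns well-typed empty results, and pass_log_path is now computed outside the run_evaluation branch, so the analysis step can reference the log path even when evaluation is skipped. A stub with the same early-exit contract (an assumed sketch matching the signature above, with the real decomposition work elided):

    def execute_decomposition_phase(tasks_map, framework, pass_work_dir, should_run):
        # Same early-exit contract as the real function; body elided.
        failed_decomposition, final_used_splits_map = [], {}
        if not should_run or not tasks_map:
            return failed_decomposition, final_used_splits_map
        raise NotImplementedError("decomposition work elided in this sketch")

    # With the flag off, main() can unpack the tuple unconditionally and
    # hand empty results straight to the analysis step.
    failed, used_splits = execute_decomposition_phase({}, "torch", "/tmp/pass_0", False)
    assert failed == [] and used_splits == {}

Also worth noting: the old code asserted that each rectified model path exists before decomposing; the new code instead relies on run_decomposer's return value and collects failures in failed_decomposition.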
@@ -466,15 +497,7 @@ def main(args):
         failed_decomposition,
     )
 
-    print("\n" + "=" * 80)
-    if next_round_models and real_subgraph_size > 1:
-        print(f">>> [SUGGESTION] Issues remain (Count: {len(next_round_models)}).")
-        print(">>> Please start next round decomposition test (Run this script again).")
-    elif next_round_models and real_subgraph_size <= 1:
-        print(">>> [FAILURE] Minimal granularity reached, but errors persist.")
-    else:
-        print(">>> [SUCCESS] Debugging converged.")
-    print("=" * 80)
+    print_final_summary(next_round_models, real_subgraph_size)
 
 
 if __name__ == "__main__":
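print_final_summary encodes the debugger's stopping rule: keep splitting while failures remain and the subgraphs are still divisible, report [FAILURE] once granularity 1 still misbehaves, and [SUCCESS] once no failures remain. The committed script asks the user to rerun manually; a hypothetical wrapper that automates the loop (run_single_pass and its return values are assumptions, since main() prints the outcome rather than returning it):

    def drive_until_converged(run_single_pass):
        # run_single_pass() is assumed to return the same pair that main()
        # hands to print_final_summary.
        while True:
            next_round_models, real_subgraph_size = run_single_pass()
            if not next_round_models:
                return "SUCCESS"  # debugging converged
            if real_subgraph_size <= 1:
                return "FAILURE"  # minimal granularity, errors persist
            # otherwise: [SUGGESTION] branch, run another finer-grained pass

    # Toy sequence: 4 failing models at size 8, then 1 at size 2, then none.
    results = iter([(["m"] * 4, 8), (["m"], 2), ([], 1)])
    print(drive_until_converged(lambda: next(results)))  # SUCCESS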
