Skip to content

Commit 8027793

Browse files
committed
Optimize the definition of decompose config and fix the config saving of test_target_device.
1 parent bd84a36 commit 8027793

File tree

1 file changed

+92
-74
lines changed

1 file changed

+92
-74
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 92 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -123,17 +123,16 @@ def save_decompose_config(
123123
):
124124
"""Saves the current state to a JSON file."""
125125

126-
active_models_map = {}
127-
split_positions_map = {}
126+
tasks_map_copy = {}
128127
for model_name, task_info in tasks_map.items():
129-
active_models_map[model_name] = task_info["original_path"]
130-
split_positions_map[model_name] = task_info["split_positions"]
128+
tasks_map_copy[model_name] = {}
129+
for key in ["original_path", "split_positions"]:
130+
tasks_map_copy[model_name][key] = task_info[key]
131131

132132
config = {
133133
"max_subgraph_size": max_subgraph_size,
134134
"incorrect_models": list(incorrect_paths),
135-
"active_models_map": active_models_map,
136-
"split_positions_map": split_positions_map,
135+
"tasks_map": tasks_map_copy,
137136
"failed_decomposition_models": list(failed_decomposition_models),
138137
}
139138
config_path = get_decompose_config_path(work_dir)
@@ -149,7 +148,7 @@ def get_model_name_with_subgraph_tag(model_path):
149148
return f"{fields[-2]}_{fields[-1]}" if re.match(pattern, fields[-1]) else fields[-1]
150149

151150

152-
def run_naive_decomposer(
151+
def run_decomposer_for_single_model(
153152
framework: str,
154153
model_path: str,
155154
output_dir: str,
@@ -201,6 +200,32 @@ def run_naive_decomposer(
201200
return True
202201

203202

203+
def run_decomposer_for_multi_models(
204+
framework, tasks_map, decomposed_samples_dir, max_subgraph_size
205+
):
206+
failed_decomposition = []
207+
208+
for model_name, task_info in tasks_map.items():
209+
print(f"[Decomposition] max_subgraph_size: {max_subgraph_size}")
210+
original_path = task_info["original_path"]
211+
split_positions = calculate_split_positions_for_subgraph(
212+
task_info["subgraph_size"], max_subgraph_size
213+
)
214+
task_info["split_positions"] = split_positions
215+
216+
rectified_model_path = get_rectfied_model_path(original_path)
217+
assert os.path.exists(
218+
rectified_model_path
219+
), f"{rectified_model_path} does not exist."
220+
221+
success = run_decomposer_for_single_model(
222+
framework, rectified_model_path, decomposed_samples_dir, split_positions
223+
)
224+
if not success:
225+
failed_decomposition.append(rectified_model_path)
226+
return tasks_map, failed_decomposition
227+
228+
204229
def run_evaluation(
205230
framework: str, test_cmd_b64: str, work_dir: str, log_path: str
206231
) -> int:
@@ -283,9 +308,8 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
283308
print(f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})...")
284309

285310
prev_config = load_decompose_config(prev_pass_dir)
286-
prev_active_models_map = prev_config.get("active_models_map", {})
287-
prev_split_positions_map = prev_config.get("split_positions_map", {})
288311
prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
312+
prev_tasks_map = prev_config.get("tasks_map", {})
289313

290314
# Load previous max size as fallback
291315
prev_max_subgraph_size = prev_config.get("max_subgraph_size")
@@ -294,20 +318,18 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
294318
if not prev_incorrect_subgraphs:
295319
return {}, max_subgraph_size
296320

297-
print("[Analysis] Refining splits based on previous incorrect models ...")
298-
299321
tasks_map = {}
300322
for subgraph_path in prev_incorrect_subgraphs:
301323
# Parse model name and subgraph index
302324
model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
303325
model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
304326
subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
305-
print(f"subgraph_path: {subgraph_path}")
306-
print(f"model_name: {model_name}, subgraph_idx: {subgraph_idx}")
307-
assert model_name in prev_active_models_map
327+
328+
assert model_name in prev_tasks_map
329+
pre_task_for_model = prev_tasks_map[model_name]
308330

309331
# Reconstruct previous subgraph size to locate the failing segment
310-
prev_split_positions = prev_split_positions_map.get(model_name, [])
332+
prev_split_positions = pre_task_for_model.get("split_positions", [])
311333
subgraph_size = reconstruct_subgraph_size(prev_split_positions)
312334
assert subgraph_idx < len(
313335
subgraph_size
@@ -316,49 +338,58 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
316338
if model_name not in tasks_map:
317339
tasks_map[model_name] = {
318340
"subgraph_path": subgraph_path,
319-
"original_path": prev_active_models_map[model_name],
341+
"original_path": pre_task_for_model["original_path"],
320342
"subgraph_size": subgraph_size[subgraph_idx],
321343
"split_positions": set(),
322344
}
323345

324346
return tasks_map, max_subgraph_size
325347

326348

327-
def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, pass_work_dir):
328-
"""Executes the decomposition phase (Phase 1)."""
329-
failed_decomposition = []
349+
def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
350+
if current_pass_id == 0:
351+
tasks_map, max_subgraph_size = generate_initial_tasks(args)
352+
else:
353+
tasks_map, max_subgraph_size = generate_refined_tasks(
354+
base_output_dir, current_pass_id
355+
)
356+
357+
print(f"[INFO] initial max_subgraph_size: {max_subgraph_size}")
358+
print(f"[INFO] number of incorrect models: {len(tasks_map)}")
359+
for model_name, task_info in tasks_map.items():
360+
original_path = task_info["original_path"]
361+
print(f"- {original_path}")
362+
363+
if not tasks_map:
364+
print("[FINISHED] No models need processing.")
365+
sys.exit(0)
366+
367+
if max_subgraph_size <= 0:
368+
print(
369+
f"[FINISHED] Cannot decompose with max_subgraph_size {max_subgraph_size}."
370+
)
371+
sys.exit(0)
330372

373+
return tasks_map, max_subgraph_size
374+
375+
376+
def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspace):
377+
"""Executes the decomposition phase."""
378+
379+
failed_decomposition = []
331380
need_decompose = True if len(tasks_map) > 0 else False
332-
if need_decompose:
333-
print("\n--- Phase 1: Decomposition ---", flush=True)
334381

335382
while need_decompose:
336383
decomposed_samples_dir = os.path.join(
337-
pass_work_dir, "samples" if framework == "torch" else "paddle_samples"
384+
workspace, "samples" if framework == "torch" else "paddle_samples"
338385
)
339386
if not os.path.exists(decomposed_samples_dir):
340387
os.makedirs(decomposed_samples_dir, exist_ok=True)
341-
print(f"decomposed_samples_dir: {decomposed_samples_dir}")
342-
343-
for model_name, task_info in tasks_map.items():
344-
print(f"[Decomposition] max_subgraph_size: {max_subgraph_size}")
345-
original_path = task_info["original_path"]
346-
split_positions = calculate_split_positions_for_subgraph(
347-
task_info["subgraph_size"], max_subgraph_size
348-
)
349-
task_info["split_positions"] = split_positions
350-
351-
rectified_model_path = get_rectfied_model_path(original_path)
352-
assert os.path.exists(
353-
rectified_model_path
354-
), f"{rectified_model_path} does not exist."
355-
356-
success = run_naive_decomposer(
357-
framework, rectified_model_path, decomposed_samples_dir, split_positions
358-
)
359-
if not success:
360-
failed_decomposition.append(rectified_model_path)
388+
print(f"[Decomposition] decomposed_samples_dir: {decomposed_samples_dir}")
361389

390+
tasks_map, failed_decomposition = run_decomposer_for_multi_models(
391+
framework, tasks_map, decomposed_samples_dir, max_subgraph_size
392+
)
362393
num_decomposed_samples = count_samples(decomposed_samples_dir)
363394
print(
364395
f"[Decomposition] number of graphs: {len(tasks_map)} -> {num_decomposed_samples}",
@@ -387,8 +418,8 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, pass_wo
387418
return tasks_map, failed_decomposition, max_subgraph_size
388419

389420

390-
def print_final_summary(next_round_models, max_subgraph_size):
391-
"""Prints the final suggestion/result."""
421+
def print_summary_and_suggestion(next_round_models, max_subgraph_size):
422+
"""Print suggestion/result."""
392423
print("\n" + "=" * 80)
393424
if next_round_models and max_subgraph_size > 1:
394425
print(f">>> [SUGGESTION] Issues remain (Count: {len(next_round_models)}).")
@@ -409,59 +440,48 @@ def main(args):
409440
print(f" GraphNet Auto-Debugger | ROUND: PASS_{current_pass_id}")
410441
print("=" * 80)
411442

412-
# --- Step 1: Prepare Tasks ---
413-
if current_pass_id == 0:
414-
tasks_map, max_subgraph_size = generate_initial_tasks(args)
415-
else:
416-
tasks_map, max_subgraph_size = generate_refined_tasks(
417-
base_output_dir, current_pass_id
418-
)
419-
420-
print(f"[INFO] initial max_subgraph_size: {max_subgraph_size}")
421-
print(f"[INFO] number of incorrect models: {len(tasks_map)}")
422-
for model_name, task_info in tasks_map.items():
423-
original_path = task_info["original_path"]
424-
print(f"- {original_path}")
425-
426-
if not tasks_map:
427-
print("[FINISHED] No models need processing.")
428-
sys.exit(0)
429-
if max_subgraph_size <= 0:
430-
print(
431-
f"[FINISHED] Cannot decompose with max_subgraph_size {max_subgraph_size}."
432-
)
433-
sys.exit(0)
434-
435-
# --- Step 2: Prepare Workspace ---
443+
# --- Step 1: Prepare Tasks and Workspace ---
444+
tasks_map, max_subgraph_size = prepare_tasks_and_verify(
445+
args, current_pass_id, base_output_dir
446+
)
436447
pass_work_dir = get_decompose_workspace_path(base_output_dir, current_pass_id)
437448
if not os.path.exists(pass_work_dir):
438449
os.makedirs(pass_work_dir, exist_ok=True)
439450

440-
# --- Step 3: Decomposition ---
451+
# --- Step 2: Decomposition ---
441452
failed_decomposition = []
442453
if task_controller.task_scheduler["run_decomposer"]:
454+
print("\n--- Phase 1: Decomposition ---", flush=True)
443455
(
444456
tasks_map,
445457
failed_decomposition,
446458
max_subgraph_size,
447459
) = execute_decomposition_phase(
448460
max_subgraph_size, tasks_map, args.framework, pass_work_dir
449461
)
462+
else:
463+
config = load_decompose_config(pass_work_dir)
464+
max_subgraph_size = config["max_subgraph_size"]
465+
failed_decomposition = config["failed_decomposition_models"]
466+
tasks_map = config.get("tasks_map", {})
450467

451-
# --- Step 4: Testing ---
468+
# --- Step 3: Evaluation ---
452469
pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
453470
if task_controller.task_scheduler["run_evaluation"]:
454471
print("\n--- Phase 2: Evaluation ---")
455472
run_evaluation(args.framework, args.test_config, pass_work_dir, pass_log_path)
456473

457-
# --- Step 5: Analysis ---
474+
# --- Step 4: Analysis ---
458475
next_round_models = set()
459476
if task_controller.task_scheduler["post_analysis"]:
460477
print("\n--- Phase 3: Analysis ---")
461478
next_round_models = get_incorrect_models(args.tolerance, pass_log_path)
462479
print(f"[Analysis] Found {len(next_round_models)} incorrect subgraphs.\n")
463480

464-
# --- Step 6: Save State ---
481+
print_summary_and_suggestion(next_round_models, max_subgraph_size)
482+
print()
483+
484+
# --- Step 5: Save States ---
465485
save_decompose_config(
466486
pass_work_dir,
467487
max_subgraph_size,
@@ -470,8 +490,6 @@ def main(args):
470490
failed_decomposition,
471491
)
472492

473-
print_final_summary(next_round_models, max_subgraph_size)
474-
475493

476494
if __name__ == "__main__":
477495
parser = argparse.ArgumentParser()

0 commit comments

Comments (0)