Commit 55b6bd2

Update to use the original model_path to decompose.
1 parent 53b0e4f commit 55b6bd2

File tree

3 files changed (+212, -469 lines)

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 210 additions & 61 deletions
@@ -85,29 +85,40 @@ def count_samples(samples_dir):
     return num_samples
 
 
-def load_prev_config(pass_id: int, output_dir: str) -> Dict[str, Any]:
+def get_decompose_config_path(output_dir: str) -> str:
+    """Returns the full path to the decompose configuration file."""
+    return os.path.join(output_dir, "decompose_config.json")
+
+
+def load_decompose_config(pass_id: int, output_dir: str) -> Dict[str, Any]:
     """Loads the configuration file from the previous pass."""
     prev_dir = os.path.join(output_dir, f"pass_{pass_id - 1}")
-    config_path = os.path.join(prev_dir, "decompose_config.json")
+    config_path = get_decompose_config_path(prev_dir)
+
     if not os.path.exists(config_path):
         raise FileNotFoundError(f"Missing configuration file: {config_path}")
     with open(config_path, "r") as f:
         return json.load(f)
 
 
-def save_current_config(
+def save_decompose_config(
     work_dir: str,
-    current_max_size: int,
-    incorrect_models: Union[List[str], Set[str]],
-    failed_models: List[str],
+    max_subgraph_size: int,
+    incorrect_paths: Union[List[str], Set[str]],
+    active_models_map: Dict[str, str],
+    split_positions_map: Dict[str, List[int]],
+    failed_decomposition_models: Union[List[str], Set[str]],
 ):
-    """Saves the current state."""
+    """Saves the current state to a JSON file."""
     config = {
-        "current_max_subgraph_size": current_max_size,
-        "incorrect_models": list(incorrect_models),
-        "failed_extraction_models": list(failed_models),
+        "max_subgraph_size": max_subgraph_size,
+        "incorrect_models": list(incorrect_paths),
+        "active_models_map": active_models_map,
+        "split_positions_map": split_positions_map,
+        "failed_decomposition_models": list(failed_decomposition_models),
     }
-    config_path = os.path.join(work_dir, "decompose_config.json")
+    config_path = get_decompose_config_path(work_dir)
+
     with open(config_path, "w") as f:
         json.dump(config, f, indent=4)
     print(f"[INFO] State saved to: {config_path}")
@@ -123,15 +134,10 @@ def run_decomposer(
     framework: str,
     model_path: str,
     output_dir: str,
-    max_subgraph_size: int,
+    split_positions: List[int],
 ) -> bool:
     """Decomposes a single model."""
 
-    upper_bound = 4096
-    split_positions = list(
-        range(max_subgraph_size, upper_bound + max_subgraph_size, max_subgraph_size)
-    )
-
     graphnet_root = path_utils.get_graphnet_root()
     model_name = get_model_name_with_subgraph_tag(model_path)
     decorator_config = {
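run_decomposer no longer derives a uniform grid internally; the caller supplies split_positions. For comparison, the deleted code always produced a fixed stride up to the 4096 upper bound, e.g. for max_subgraph_size=1024:

    list(range(1024, 4096 + 1024, 1024))  # -> [1024, 2048, 3072, 4096]

Passing the positions in from main is what makes the per-model bisection refinement added below possible.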
@@ -142,7 +148,7 @@ def run_decomposer(
         "custom_extractor_config": {
             "output_dir": output_dir,
             "split_positions": split_positions,
-            "group_head_and_tail": True,
+            "group_head_and_tail": False,
             "chain_style": False,
         },
     },
@@ -151,9 +157,9 @@ def run_decomposer(
         json.dumps(decorator_config).encode()
     ).decode()
 
-    print(
-        f"- [Decomposing] {model_name} (max_subgraph_size={max_subgraph_size}, split_positions={split_positions})"
-    )
+    print(f"[Decomposing] {model_path}")
+    print(f"[Strategy] split_positions: {split_positions}")
+
     cmd = [
         sys.executable,
         "-m",
@@ -196,8 +202,8 @@ def run_evaluation(
         for item in (f"--{key}", str(value))
     ]
 
-    print(f" [Batch Testing] Logging to: {log_path}")
-    print(f" [Command] {' '.join(cmd)}")
+    print(f"[Batch Testing] Logging to: {log_path}")
+    print(f"[Command] {' '.join(cmd)}")
 
     os.makedirs(os.path.dirname(log_path), exist_ok=True)
     with open(log_path, "w") as f:
@@ -209,6 +215,47 @@
         sys.exit(proc.returncode)
 
 
+def reconstruct_subgraph_size(split_positions: List[int]) -> List[tuple]:
+    """Reconstructs subgraph intervals based on sorted split positions."""
+    full_splits = sorted(list(set(split_positions)))
+
+    subgraph_size = []
+    # Needs at least 2 points to form a subgraph interval
+    if len(full_splits) < 2:
+        return []
+
+    for i in range(len(full_splits) - 1):
+        subgraph_size.append((full_splits[i], full_splits[i + 1]))
+
+    return subgraph_size
+
+
+def calculate_current_subgraph_size(
+    tasks_map: Dict[str, Dict], fallback_size: int
+) -> int:
+    """Calculates the current subgraph size from generated tasks."""
+    current_subgraph_size = float("inf")
+    found_splits = False
+
+    for _, info in tasks_map.items():
+        splits = sorted(list(info["split_positions"]))
+
+        if len(splits) < 2:
+            continue
+
+        found_splits = True
+        for i in range(len(splits) - 1):
+            diff = splits[i + 1] - splits[i]
+            if diff > 0:
+                current_subgraph_size = min(current_subgraph_size, diff)
+
+    return (
+        int(current_subgraph_size)
+        if found_splits and current_subgraph_size != float("inf")
+        else fallback_size
+    )
+
+
 def main(args):
     task_controller = TaskController(args)
     base_output_dir = task_controller.root_output_dir
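The two new helpers are pure functions, so their behavior is easy to pin down with small traces (inputs illustrative):

    reconstruct_subgraph_size([0, 1024, 2048])  # -> [(0, 1024), (1024, 2048)]
    reconstruct_subgraph_size([1024])           # -> [] (needs at least two points)

    tasks_map = {"model_a": {"split_positions": {0, 512, 1024}}}
    calculate_current_subgraph_size(tasks_map, fallback_size=256)  # -> 512

calculate_current_subgraph_size returns the smallest gap between adjacent split positions across all tasks, falling back to fallback_size when no task has two or more splits.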
@@ -218,28 +265,119 @@ def main(args):
     print(f" GraphNet Auto-Debugger | ROUND: PASS_{current_pass_id}")
     print("=" * 80)
 
-    # --- Step 1: Initialize State ---
-    target_models = []
-    current_max_size = args.max_subgraph_size
+    tasks_map = {}
+    active_models_map_for_save = {}
+    kMaxGraphSize = 4096
+
+    # Initialize using the argument passed from bash
+    max_subgraph_size = args.max_subgraph_size
+
     if current_pass_id == 0:
         print(f"[Init] Pass 0: Reading from log file: {args.log_file}")
-        current_max_size = args.max_subgraph_size
-        target_models = get_incorrect_models(args.tolerance, args.log_file)
-    else:
-        print(f"[Init] Resuming from Pass {current_pass_id - 1}...")
-        prev_config = load_prev_config(current_pass_id, base_output_dir)
-        target_models = prev_config.get("incorrect_models", [])
+        initial_failures = get_incorrect_models(args.tolerance, args.log_file)
 
-        prev_size = prev_config.get("current_max_subgraph_size", 2048)
-        current_max_size = max(1, prev_size // 2)
+        # Dynamic generation based on step size (args.max_subgraph_size)
+        initial_splits = list(range(0, kMaxGraphSize + 1, max_subgraph_size))
 
-    print(f"[INFO] current max_subgraph_size: {current_max_size}")
-    print(f"[INFO] number of incorrect models: {len(target_models)}")
-    for model_path in target_models:
-        print(f"- {model_path}")
+        for path in initial_failures:
+            name = os.path.basename(path.rstrip("/"))
+            active_models_map_for_save[name] = path
+            tasks_map[name] = {
+                "original_path": path,
+                "split_positions": set(initial_splits),
+            }
+    else:
+        prev_pass_dir = os.path.join(base_output_dir, f"pass_{current_pass_id - 1}")
+        print(
+            f"[Init] Resuming from Pass {current_pass_id - 1} (Dir: {prev_pass_dir})..."
+        )
 
-    if not target_models:
-        print(f"[FINISHED] Debugging completed.")
+        prev_config = load_decompose_config(current_pass_id, base_output_dir)
+        prev_map = prev_config.get("active_models_map", {})
+
+        prev_used_splits = prev_config.get("split_positions_map", {})
+        prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])
+
+        # Load previous max size as fallback for calculation
+        prev_max_size = prev_config.get("max_subgraph_size", args.max_subgraph_size)
+        max_subgraph_size = prev_max_size
+
+        if not prev_incorrect_subgraphs:
+            print(f"[FINISHED] Debugging completed.")
+            sys.exit(0)
+
+        print(f"[Analysis] Refining splits based on failures...")
+
+        for sub_path in prev_incorrect_subgraphs:
+            parts = sub_path.rstrip("/").split("/")
+            if len(parts) < 2:
+                continue
+
+            subgraph_dirname = parts[-1]
+            model_name = parts[-2]
+
+            if model_name in prev_map:
+                active_models_map_for_save[model_name] = prev_map[model_name]
+                if model_name not in tasks_map:
+                    tasks_map[model_name] = {
+                        "original_path": prev_map[model_name],
+                        "split_positions": set(),
+                    }
+            else:
+                continue
+
+            try:
+                sub_idx = int(subgraph_dirname.split("_")[-1])
+            except ValueError:
+                continue
+
+            # 1. Reconstruct the previous subgraph intervals to locate the failing segment
+            old_split_position = sorted(prev_used_splits.get(model_name, []))
+            subgraph_size = reconstruct_subgraph_size(old_split_position)
+
+            if sub_idx >= len(subgraph_size):
+                print(
+                    f"[WARN] Index {sub_idx} out of bounds for {model_name} (old split position: {old_split_position})"
+                )
+                continue
+
+            # 2. Get the specific failing subgraph interval [start, end]
+            fail_start, fail_end = subgraph_size[sub_idx]
+
+            # Clamp an open-ended interval to the maximum graph size
+            if fail_end == float("inf"):
+                fail_end = kMaxGraphSize
+
+            # Dynamic step calculation
+            subgraph_size_len = fail_end - fail_start
+            new_step = subgraph_size_len // 2
+
+            if new_step < 1:
+                new_step = subgraph_size_len
+
+            # 3. Calculate the midpoint
+            mid_point = fail_start + new_step
+
+            # 4. Add split positions
+            if mid_point > fail_start and mid_point < fail_end:
+                tasks_map[model_name]["split_positions"].update(
+                    [int(fail_start), int(mid_point), int(fail_end)]
+                )
+            else:
+                tasks_map[model_name]["split_positions"].update(
+                    [int(fail_start), int(fail_end)]
+                )
+
+    # Recalculate based on current map to ensure log accuracy
+    real_subgraph_size = calculate_current_subgraph_size(tasks_map, max_subgraph_size)
+    print(f"[INFO] Current Subgraph Size: {real_subgraph_size}")
+    print(f"[INFO] Models to Process: {len(tasks_map)}")
+    for model_name, task_info in tasks_map.items():
+        original_path = task_info["original_path"]
+        print(f"- {original_path}")
+
+    if not tasks_map:
+        print(f"[FINISHED] No models need processing.")
         sys.exit(0)
 
     # --- Step 2: Prepare Workspace ---
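To make the refinement concrete: suppose a previous pass split model_a at [0, 1024, 2048] and the evaluator flagged .../model_a/subgraph_1 (names illustrative). The loop above then computes:

    sub_idx = 1  # parsed from the "subgraph_1" directory name
    reconstruct_subgraph_size([0, 1024, 2048])  # -> [(0, 1024), (1024, 2048)]
    fail_start, fail_end = 1024, 2048           # interval at index 1
    new_step = (2048 - 1024) // 2               # 512
    mid_point = 1024 + 512                      # 1536
    # split_positions for model_a gains {1024, 1536, 2048}

Each failing interval is bisected, so the effective subgraph size roughly halves per pass until new_step would drop below 1 and the interval can no longer be split.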
@@ -250,42 +388,48 @@ def main(args):
     # --- Step 3: Decomposition ---
     need_decompose = (
         True
-        if task_controller.task_scheduler["run_decomposer"] and len(target_models) > 0
+        if task_controller.task_scheduler["run_decomposer"] and len(tasks_map) > 0
         else False
     )
     if need_decompose:
         print("\n--- Phase 1: Decomposition ---", flush=True)
-        failed_extraction = []
+
+        failed_decomposition = []
+        final_used_splits_map = {}
         while need_decompose:
             decomposed_samples_dir = os.path.join(
                 pass_work_dir, "samples" if args.framework == "torch" else "paddle_samples"
             )
             os.makedirs(decomposed_samples_dir, exist_ok=True)
 
-            for idx, model_path in enumerate(target_models):
-                rectied_model_path = get_rectfied_model_path(model_path)
+            for model_name, task_info in tasks_map.items():
+                original_path = task_info["original_path"]
+                split_positions = sorted(list(task_info["split_positions"]))
+                final_used_splits_map[model_name] = split_positions
+
+                rectied_model_path = get_rectfied_model_path(original_path)
                 assert os.path.exists(
                     rectied_model_path
                 ), f"{rectied_model_path} does not exist."
 
-                os.makedirs(decomposed_samples_dir, exist_ok=True)
                 success = run_decomposer(
                     args.framework,
                     rectied_model_path,
                     decomposed_samples_dir,
-                    current_max_size,
+                    split_positions,
                 )
                 if not success:
-                    failed_extraction.append(rectied_model_path)
+                    failed_decomposition.append(rectied_model_path)
+
             num_decomposed_samples = count_samples(decomposed_samples_dir)
             print(
-                f"- number of graphs: {len(target_models)} -> {num_decomposed_samples}",
+                f"- number of graphs: {len(tasks_map)} -> {num_decomposed_samples}",
                 flush=True,
             )
-            if failed_extraction:
-                print(f"[WARN] {len(failed_extraction)} models failed to decompose.")
+            if failed_decomposition:
+                print(f"[WARN] {len(failed_decomposition)} models failed to decompose.")
 
-            if num_decomposed_samples == len(target_models):
+            if not failed_decomposition and num_decomposed_samples == len(tasks_map):
                 need_decompose = True
                 shutil.rmtree(decomposed_samples_dir)
                 os.makedirs(decomposed_samples_dir, exist_ok=True)
@@ -300,26 +444,31 @@ def main(args):
         run_evaluation(args.framework, args.test_config, pass_work_dir, pass_log_path)
 
     # --- Step 5: Analysis ---
+    print("\n--- Phase 3: Analysis ---")
+    next_round_models = set(get_incorrect_models(args.tolerance, pass_log_path))
+    print(f"[Result] Found {len(next_round_models)} incorrect subgraphs.")
+
     next_round_models = set()
     if task_controller.task_scheduler["post_analysis"]:
         print("\n--- Phase 3: Analysis ---")
-        next_round_models = set()
-        try:
-            next_round_models = set(get_incorrect_models(args.tolerance, pass_log_path))
-            print(f" [Result] Found {len(next_round_models)} incorrect subgraphs.")
-        except Exception as e:
-            print(f" [ERROR] Log analysis failed: {e}")
+        next_round_models = set(get_incorrect_models(args.tolerance, pass_log_path))
+        print(f"[Result] Found {len(next_round_models)} incorrect subgraphs.")
 
     # --- Step 6: Save State ---
-    save_current_config(
-        pass_work_dir, current_max_size, next_round_models, failed_extraction
+    save_decompose_config(
+        pass_work_dir,
+        real_subgraph_size,
+        next_round_models,
+        active_models_map_for_save,
+        final_used_splits_map,
+        failed_decomposition,
     )
 
     print("\n" + "=" * 80)
-    if next_round_models and current_max_size > 1:
+    if next_round_models and real_subgraph_size > 1:
         print(f">>> [SUGGESTION] Issues remain (Count: {len(next_round_models)}).")
         print(">>> Please start next round decomposition test (Run this script again).")
-    elif next_round_models and current_max_size <= 1:
+    elif next_round_models and real_subgraph_size <= 1:
         print(f">>> [FAILURE] Minimal granularity reached, but errors persist.")
     else:
         print(f">>> [SUCCESS] Debugging converged.")
