Commit 69bc5d7

Fix model_name and refine some code.
1 parent becb25c commit 69bc5d7

File tree

1 file changed: +49 −39 lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 49 additions & 39 deletions
@@ -1,5 +1,6 @@
 import os
 import sys
+import re
 import json
 import base64
 import shutil
@@ -9,7 +10,7 @@
 from typing import List, Set, Dict, Any, Union
 import graph_net
 from graph_net.analysis_util import get_incorrect_models
-from graph_net import path_utils, test_compiler_util
+from graph_net import path_utils


 def convert_b64_string_to_json(b64str):
@@ -22,16 +23,16 @@ def __init__(self, args):
         self.test_config = convert_b64_string_to_json(args.test_config)
         assert "test_module_name" in self.test_config

-        test_module_name = self.test_config["test_module_name"]
+        self.test_module_name = self.test_config["test_module_name"]
         max_pass_id = self._determine_max_pass_id(self.root_output_dir)
         self.current_pass_id = (
-            max_pass_id if test_module_name == "test_target_device" else max_pass_id + 1
-        )
-        print(
-            f"test_module_name: {test_module_name}, current_pass_id: {self.current_pass_id}"
+            max_pass_id
+            if self.test_module_name == "test_target_device"
+            else max_pass_id + 1
         )

-        self._init_task_scheduler(test_module_name)
+        self._init_task_scheduler(self.test_module_name)
+        self._print()

     def _determine_max_pass_id(self, output_dir: str) -> int:
         """Scans the output directory to determine the next pass ID."""
@@ -71,6 +72,14 @@ def _init_task_scheduler(self, test_module_name):
             "post_analysis": True,
         }

+    def _print(self):
+        print(
+            f"[TaskController] test_module_name: {self.test_module_name}, current_pass_id: {self.current_pass_id}",
+            flush=True,
+        )
+        print(f"[TaskController] task_scheduler: {self.task_scheduler}", flush=True)
+        print()
+

 def get_rectfied_model_path(model_path):
     graphnet_root = path_utils.get_graphnet_root()
@@ -90,10 +99,13 @@ def get_decompose_config_path(output_dir: str) -> str:
     return os.path.join(output_dir, "decompose_config.json")


-def load_decompose_config(pass_id: int, output_dir: str) -> Dict[str, Any]:
+def get_decompose_workspace_path(output_dir, pass_id):
+    return os.path.join(output_dir, f"pass_{pass_id}")
+
+
+def load_decompose_config(work_dir: str) -> Dict[str, Any]:
     """Loads the configuration file from the previous pass."""
-    prev_dir = os.path.join(output_dir, f"pass_{pass_id - 1}")
-    config_path = get_decompose_config_path(prev_dir)
+    config_path = get_decompose_config_path(work_dir)

     if not os.path.exists(config_path):
         raise FileNotFoundError(f"Missing configuration file: {config_path}")
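
The new helper centralizes the `pass_<id>` directory naming, and `load_decompose_config` now receives the resolved directory instead of computing it internally. A minimal usage sketch, with hypothetical paths:

# Sketch only; base_output_dir and the pass id are made up.
base_output_dir = "/tmp/decompose_out"
prev_pass_dir = get_decompose_workspace_path(base_output_dir, 2)
# -> "/tmp/decompose_out/pass_2"
prev_config = load_decompose_config(prev_pass_dir)  # raises FileNotFoundError if absent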
@@ -125,9 +137,9 @@ def save_decompose_config(


 def get_model_name_with_subgraph_tag(model_path):
-    model_name = test_compiler_util.get_model_name(model_path)
-    subgraph_tag = test_compiler_util.get_subgraph_tag(model_path)
-    return f"{model_name}_{subgraph_tag}" if subgraph_tag else model_name
+    fields = model_path.rstrip("/").split(os.sep)
+    pattern = rf"^subgraph(_\d+)?$"
+    return f"{fields[-2]}_{fields[-1]}" if re.match(pattern, fields[-1]) else fields[-1]


 def run_decomposer(
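
The rewritten helper infers the tag from the directory layout itself rather than calling into `test_compiler_util`. For illustration, with hypothetical paths on a POSIX system (where `os.sep == "/"`):

# Hypothetical paths; shows the regex-based tagging.
get_model_name_with_subgraph_tag("/models/resnet50/subgraph_3")  # -> "resnet50_subgraph_3"
get_model_name_with_subgraph_tag("/models/resnet50/subgraph")    # -> "resnet50_subgraph"
get_model_name_with_subgraph_tag("/models/resnet50")             # -> "resnet50"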
@@ -287,14 +299,15 @@ def main(args):
             "split_positions": set(initial_splits),
         }
     else:
-        prev_pass_dir = os.path.join(base_output_dir, f"pass_{current_pass_id - 1}")
+        prev_pass_dir = get_decompose_workspace_path(
+            base_output_dir, current_pass_id - 1
+        )
         print(
-            f"[Init] Resuming from Pass {current_pass_id - 1} (Dir: {prev_pass_dir})..."
+            f"[Init] Resuming from Pass_{current_pass_id - 1} (Dir: {prev_pass_dir})..."
         )

-        prev_config = load_decompose_config(current_pass_id, base_output_dir)
-        prev_map = prev_config.get("active_models_map", {})
-
+        prev_config = load_decompose_config(prev_pass_dir)
+        prev_tasks_map = prev_config.get("active_models_map", {})
         prev_used_splits = prev_config.get("split_positions_map", {})
         prev_incorrect_subgraphs = prev_config.get("incorrect_models", [])


@@ -308,41 +321,35 @@ def main(args):
308321

309322
print(f"[Analysis] Refining splits based on failures...")
310323

311-
for sub_path in prev_incorrect_subgraphs:
312-
parts = sub_path.rstrip("/").split("/")
313-
if len(parts) < 2:
314-
continue
324+
for subgraph_path in prev_incorrect_subgraphs:
325+
print(f"- subgraph_path: {subgraph_path}")
326+
model_name_with_subgraph_idx = subgraph_path.rstrip("/").split(os.sep)[-1]
327+
model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
328+
subgraph_idx = int(model_name_with_subgraph_idx.split("_")[-1])
329+
print(f"- model_name: {model_name}, subgraph_idx: {subgraph_idx}")
315330

316-
subgraph_dirname = parts[-1]
317-
model_name = parts[-2]
318-
319-
if model_name in prev_map:
320-
active_models_map_for_save[model_name] = prev_map[model_name]
331+
if model_name in prev_tasks_map:
332+
active_models_map_for_save[model_name] = prev_tasks_map[model_name]
321333
if model_name not in tasks_map:
322334
tasks_map[model_name] = {
323-
"original_path": prev_map[model_name],
335+
"original_path": prev_tasks_map[model_name],
324336
"split_positions": set(),
325337
}
326338
else:
327339
continue
328340

329-
try:
330-
sub_idx = int(subgraph_dirname.split("_")[-1])
331-
except ValueError:
332-
continue
333-
334341
# 1. Reconstruct previous subgraph size to locate the failing segment
335-
old_split_position = sorted(prev_used_splits.get(model_name, []))
336-
subgraph_size = reconstruct_subgraph_size(old_split_position)
342+
prev_split_positions = sorted(prev_used_splits.get(model_name, []))
343+
subgraph_size = reconstruct_subgraph_size(prev_split_positions)
337344

338-
if sub_idx >= len(subgraph_size):
345+
if subgraph_idx >= len(subgraph_size):
339346
print(
340-
f"[WARN] Index {sub_idx} out of bounds for {model_name} (old split position: {old_split_position})"
347+
f"[WARN] Index {subgraph_idx} out of bounds for {model_name} (previous split_positions: {prev_split_positions})"
341348
)
342349
continue
343350

344351
# 2. Get the specific failing subgraph size [Start, End]
345-
fail_start, fail_end = subgraph_size[sub_idx]
352+
fail_start, fail_end = subgraph_size[subgraph_idx]
346353

347354
# though intervals logic usually handles this via float('inf') replacement if used.
348355
if fail_end == float("inf"):
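
The rewritten loop assumes incorrect-model directory names of the form `<model_name>_<idx>` and trusts the trailing integer (the old `try/except ValueError` guard is gone). Assuming `reconstruct_subgraph_size` (not shown in this diff) expands the sorted split positions into `[start, end]` segments whose last end is `float("inf")`, the indexing works out as in this sketch:

# Assumed behavior of reconstruct_subgraph_size; for illustration only.
prev_split_positions = [10, 25]
subgraph_size = [(0, 10), (10, 25), (25, float("inf"))]
# A hypothetical directory ".../resnet50_2" parses to model_name="resnet50", subgraph_idx=2:
fail_start, fail_end = subgraph_size[2]  # -> (25, inf), the trailing segment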
@@ -381,7 +388,7 @@ def main(args):
         sys.exit(0)

     # --- Step 2: Prepare Workspace ---
-    pass_work_dir = os.path.join(base_output_dir, f"pass_{current_pass_id}")
+    pass_work_dir = get_decompose_workspace_path(base_output_dir, current_pass_id)
     if not os.path.exists(pass_work_dir):
         os.makedirs(pass_work_dir, exist_ok=True)

@@ -401,6 +408,7 @@ def main(args):
         pass_work_dir, "samples" if args.framework == "torch" else "paddle_samples"
     )
     os.makedirs(decomposed_samples_dir, exist_ok=True)
+    print(f"decomposed_samples_dir: {decomposed_samples_dir}")

     for model_name, task_info in tasks_map.items():
         original_path = task_info["original_path"]
@@ -409,6 +417,8 @@ def main(args):
         final_used_splits_map[model_name] = split_positions

         rectied_model_path = get_rectfied_model_path(original_path)
+        print(f"original_path: {original_path}")
+        print(f"rectied_model_path: {rectied_model_path}")
         assert os.path.exists(
             rectied_model_path
         ), f"{rectied_model_path} does not exist."
