Skip to content

Commit 19cefd5

Browse files
committed
Support separate testing of reference and target device.
1 parent 8c8e315 commit 19cefd5

File tree

2 files changed

+160
-107
lines changed

2 files changed

+160
-107
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 121 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,58 @@
1212
from graph_net import path_utils, test_compiler_util
1313

1414

15+
def convert_b64_string_to_json(b64str):
    """Decode a base64-encoded UTF-8 payload and parse it as JSON."""
    raw = base64.b64decode(b64str)
    return json.loads(raw.decode("utf-8"))
17+
18+
19+
class TaskController:
    """Per-run state for the auto-debugger: resolves the output directory,
    computes the next pass id, parses the base64 test config, and decides
    which pipeline phases run for the configured test module."""

    # Phase schedule for each supported test module.
    _SCHEDULES = {
        "test_compiler": {
            "run_decomposer": True,
            "run_evaluation": True,
            "post_analysis": True,
        },
        "test_reference_device": {
            "run_decomposer": True,
            "run_evaluation": True,
            "post_analysis": False,
        },
        "test_target_device": {
            "run_decomposer": False,
            "run_evaluation": True,
            "post_analysis": True,
        },
    }

    def __init__(self, args):
        self.root_output_dir = os.path.abspath(args.output_dir)
        self.pass_id = self._determine_current_pass_id(self.root_output_dir)
        self.test_config = convert_b64_string_to_json(args.test_config)
        assert "test_module_name" in self.test_config

        self._init_task_scheduler(self.test_config["test_module_name"])

    def _determine_current_pass_id(self, output_dir: str) -> int:
        """Scans the output directory to determine the next pass ID."""
        if not os.path.exists(output_dir):
            return 0
        found_ids = []
        for entry in glob.glob(os.path.join(output_dir, "pass_*")):
            pieces = os.path.basename(entry).split("_")
            # Only directories named exactly "pass_<digits>" count.
            if len(pieces) == 2 and pieces[1].isdigit():
                found_ids.append(int(pieces[1]))
        return max(found_ids) + 1 if found_ids else 0

    def _init_task_scheduler(self, test_module_name):
        """Selects the phase schedule for the given test module name."""
        # Same membership guard as before; unknown names fail loudly.
        assert test_module_name in [
            "test_compiler",
            "test_reference_device",
            "test_target_device",
        ]
        # Copy so later mutation of task_scheduler can't leak into the class.
        self.task_scheduler = dict(self._SCHEDULES[test_module_name])
1567
def get_rectfied_model_path(model_path):
1668
graphnet_root = path_utils.get_graphnet_root()
1769
return os.path.join(graphnet_root, model_path.split("GraphNet/")[-1])
@@ -25,20 +77,6 @@ def count_samples(samples_dir):
2577
return num_samples
2678

2779

28-
def determine_current_pass_id(output_dir: str) -> int:
29-
"""Scans the output directory to determine the next pass ID."""
30-
if not os.path.exists(output_dir):
31-
return 0
32-
existing_passes = glob.glob(os.path.join(output_dir, "pass_*"))
33-
valid_ids = []
34-
for p in existing_passes:
35-
basename = os.path.basename(p)
36-
parts = basename.split("_")
37-
if len(parts) == 2 and parts[1].isdigit():
38-
valid_ids.append(int(parts[1]))
39-
return max(valid_ids) + 1 if valid_ids else 0
40-
41-
4280
def load_prev_config(pass_id: int, output_dir: str) -> Dict[str, Any]:
4381
"""Loads the configuration file from the previous pass."""
4482
prev_dir = os.path.join(output_dir, f"pass_{pass_id - 1}")
@@ -78,32 +116,36 @@ def run_decomposer(
78116
model_path: str,
79117
output_dir: str,
80118
max_subgraph_size: int,
81-
decorator_config: Dict[str, Any],
82119
) -> bool:
83120
"""Decomposes a single model."""
84-
# 1. Calculate dynamic split positions
121+
85122
upper_bound = 4096
86123
split_positions = list(
87124
range(max_subgraph_size, upper_bound + max_subgraph_size, max_subgraph_size)
88125
)
89126

90-
# 2. Deep copy the template
91-
final_decorator_config = json.loads(json.dumps(decorator_config))
92-
93-
# 3. Locate the nested dictionary to inject values
94-
decorator_cfg = final_decorator_config["decorator_config"]
95-
127+
graphnet_root = path_utils.get_graphnet_root()
96128
model_name = get_model_name_with_subgraph_tag(model_path)
97-
decorator_cfg["name"] = model_name
98-
99-
custom_cfg = decorator_cfg.get("custom_extractor_config", {})
100-
custom_cfg["output_dir"] = output_dir
101-
custom_cfg["split_positions"] = split_positions
102-
103-
# 4. Encode and Run
104-
decorator_config_json = json.dumps(final_decorator_config)
105-
decorator_config_b64 = base64.b64encode(decorator_config_json.encode()).decode()
129+
decorator_config = {
130+
"decorator_path": f"{graphnet_root}/graph_net/{framework}/extractor.py",
131+
"decorator_config": {
132+
"name": model_name,
133+
"custom_extractor_path": f"{graphnet_root}/graph_net/{framework}/naive_graph_decomposer.py",
134+
"custom_extractor_config": {
135+
"output_dir": output_dir,
136+
"split_positions": split_positions,
137+
"group_head_and_tail": True,
138+
"chain_style": False,
139+
},
140+
},
141+
}
142+
decorator_config_b64 = base64.b64encode(
143+
json.dumps(decorator_config).encode()
144+
).decode()
106145

146+
print(
147+
f"- [Decomposing] {model_name} (max_subgraph_size={max_subgraph_size}, split_positions={split_positions})"
148+
)
107149
cmd = [
108150
sys.executable,
109151
"-m",
@@ -113,74 +155,64 @@ def run_decomposer(
113155
"--decorator-config",
114156
decorator_config_b64,
115157
]
116-
117-
print(
118-
f"- [Decomposing] {model_name} (max_subgraph_size={max_subgraph_size}, split_positions={split_positions})"
119-
)
120-
121158
result = subprocess.run(
122159
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
123160
)
124161
if result.returncode != 0:
125-
print(f"[ERROR] Decomposition failed for {model_path}\n{result.stderr}")
162+
print(
163+
f"[ERROR] Decomposition failed for {model_path}\n{result.stderr}",
164+
flush=True,
165+
)
126166
return False
127167
# print(result.stdout)
128168
return True
129169

130170

131-
def run_evaluation(
    framework: str, test_cmd_b64: str, work_dir: str, log_path: str
) -> int:
    """Executes the test command on the batch directory.

    Builds a ``python -m graph_net.<framework>.<test_module_name>`` command
    from the base64-encoded test config, runs it with stdout/stderr captured
    into ``log_path``, and terminates the whole process (``sys.exit``) when
    the subprocess fails.

    Args:
        framework: Framework package name (e.g. "torch", "paddle").
        test_cmd_b64: Base64-encoded JSON test config containing
            "test_module_name" and "<test_module_name>_arguments".
        work_dir: Batch directory produced by the decomposition phase;
            always used as the model path under test.
        log_path: File that receives the subprocess output.

    Returns:
        The subprocess return code — always 0 when the function returns
        normally, since any non-zero code triggers ``sys.exit``.
    """
    test_config = convert_b64_string_to_json(test_cmd_b64)
    test_module_name = test_config["test_module_name"]
    test_module_arguments = test_config[f"{test_module_name}_arguments"]
    # The batch directory overrides any placeholder model path in the config.
    test_module_arguments["model-path"] = work_dir
    if test_module_name == "test_reference_device":
        test_module_arguments["reference-dir"] = os.path.join(
            work_dir, "reference_device_outputs"
        )
    # NOTE(review): for "test_target_device" a null "reference-dir" in the
    # config is serialized below as the literal string "None" — confirm the
    # caller always supplies a real path for that module.

    # Flatten {key: value} into ["--key", "value", ...] CLI arguments.
    cmd = [sys.executable, "-m", f"graph_net.{framework}.{test_module_name}"] + [
        item
        for key, value in test_module_arguments.items()
        for item in (f"--{key}", str(value))
    ]

    print(f" [Batch Testing] Logging to: {log_path}")
    print(f" [Command] {' '.join(cmd)}")

    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    with open(log_path, "w") as f:
        proc = subprocess.run(cmd, stdout=f, stderr=subprocess.STDOUT, text=True)
    if proc.returncode != 0:
        # Surface the captured log before aborting so failures are visible
        # in the driver's own output.
        with open(log_path, "r") as f:
            content = f.read()
        print(f"[ERROR] test failed for {work_dir}\n{content}", flush=True)
        sys.exit(proc.returncode)
    # FIX: honor the declared "-> int" contract — previously no code path
    # returned a value, so callers always received None.
    return proc.returncode
163202

164203

165-
# ==========================================================
166-
# Main Execution Flow
167-
# ==========================================================
168204
def main(args):
169-
base_output_dir = os.path.abspath(args.output_dir)
170-
current_pass_id = determine_current_pass_id(base_output_dir)
171-
172-
# Parse the Decorator Configuration Template
173-
tpl_str = base64.b64decode(args.decorator_config).decode("utf-8")
174-
decorator_template = json.loads(tpl_str)
205+
task_controller = TaskController(args)
206+
base_output_dir = task_controller.root_output_dir
207+
current_pass_id = task_controller.pass_id
175208

176209
print("=" * 80)
177210
print(f" GraphNet Auto-Debugger | ROUND: PASS_{current_pass_id}")
178211
print("=" * 80)
179212

180213
# --- Step 1: Initialize State ---
181214
target_models = []
182-
current_max_size = 2048
183-
215+
current_max_size = args.max_subgraph_size
184216
if current_pass_id == 0:
185217
print(f"[Init] Pass 0: Reading from log file: {args.log_file}")
186218
current_max_size = args.max_subgraph_size
@@ -209,8 +241,13 @@ def main(args):
209241
os.makedirs(pass_work_dir, exist_ok=True)
210242

211243
# --- Step 3: Decomposition ---
212-
print("\n--- Phase 1: Decomposition ---", flush=True)
213-
need_decompose = True if len(target_models) > 0 else False
244+
need_decompose = (
245+
True
246+
if task_controller.task_scheduler["run_decomposer"] and len(target_models) > 0
247+
else False
248+
)
249+
if need_decompose:
250+
print("\n--- Phase 1: Decomposition ---", flush=True)
214251
while need_decompose:
215252
failed_extraction = []
216253

@@ -229,7 +266,6 @@ def main(args):
229266
rectied_model_path,
230267
model_out_dir,
231268
current_max_size,
232-
decorator_template,
233269
)
234270
if not success:
235271
failed_extraction.append(rectied_model_path)
@@ -250,18 +286,21 @@ def main(args):
250286
need_decompose = False
251287

252288
# --- Step 4: Testing ---
253-
print("\n--- Phase 2: Batch Testing ---")
254-
pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
255-
run_evaluation(args.test_config, pass_work_dir, pass_log_path)
289+
if task_controller.task_scheduler["run_evaluation"]:
290+
print("\n--- Phase 2: Batch Testing ---")
291+
pass_log_path = os.path.join(pass_work_dir, "batch_test_result.log")
292+
run_evaluation(args.framework, args.test_config, pass_work_dir, pass_log_path)
256293

257294
# --- Step 5: Analysis ---
258-
print("\n--- Phase 3: Analysis ---")
259295
next_round_models = set()
260-
try:
261-
next_round_models = set(get_incorrect_models(args.tolerance, pass_log_path))
262-
print(f" [Result] Found {len(next_round_models)} incorrect subgraphs.")
263-
except Exception as e:
264-
print(f" [ERROR] Log analysis failed: {e}")
296+
if task_controller.task_scheduler["post_analysis"]:
297+
print("\n--- Phase 3: Analysis ---")
298+
next_round_models = set()
299+
try:
300+
next_round_models = set(get_incorrect_models(args.tolerance, pass_log_path))
301+
print(f" [Result] Found {len(next_round_models)} incorrect subgraphs.")
302+
except Exception as e:
303+
print(f" [ERROR] Log analysis failed: {e}")
265304

266305
# --- Step 6: Save State ---
267306
save_current_config(
@@ -287,16 +326,9 @@ def main(args):
287326
parser.add_argument(
288327
"--test-config", type=str, required=True, help="Base64 encoded test config"
289328
)
290-
parser.add_argument(
291-
"--decorator-config",
292-
type=str,
293-
required=True,
294-
help="Base64 encoded decorator config template",
295-
)
296329
parser.add_argument(
297330
"--tolerance", type=int, required=True, help="Tolerance level range [-10, 5)"
298331
)
299332
parser.add_argument("--max-subgraph-size", type=int, default=4096)
300-
301333
args = parser.parse_args()
302334
main(args)

graph_net/test/subgraph_decompose_and_evaluation_step_test.sh

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,31 @@ FRAMEWORK="paddle"
88
LOG_FILE="$GRAPH_NET_ROOT/test/log_xpu.txt"
99
#LOG_FILE="/work/GraphNet/benchmark_results/log_test_target_device-xpu_p800_nope-pd_20251111_2.txt"
1010
OUTPUT_DIR="outputs"
11-
TOLERANCE=3
11+
TOLERANCE=0
1212
INITIAL_MAX_SIZE=4096
1313

14-
test_config_json_str=$(cat <<EOF
14+
#rm -rf ${OUTPUT_DIR}
15+
16+
test_compiler_config_str=$(cat <<EOF
1517
{
16-
"module_name": "graph_net.${FRAMEWORK}.test_compiler",
17-
"arguments": {
18+
"test_module_name": "test_compiler",
19+
"test_compiler_arguments": {
20+
"model-path": null,
21+
"compiler": "nope",
22+
"device": "cuda",
23+
"warmup": 5,
24+
"trials": 20
25+
}
26+
}
27+
EOF
28+
)
29+
30+
test_reference_device_config_str=$(cat <<EOF
31+
{
32+
"test_module_name": "test_reference_device",
33+
"test_reference_device_arguments": {
34+
"model-path": null,
35+
"reference-dir": null,
1836
"compiler": "nope",
1937
"device": "cuda",
2038
"warmup": 5,
@@ -24,25 +42,29 @@ test_config_json_str=$(cat <<EOF
2442
EOF
2543
)
2644

27-
extractor_config_json_str=$(cat <<EOF
45+
test_target_device_config_str=$(cat <<EOF
2846
{
29-
"decorator_path": "$GRAPH_NET_ROOT/${FRAMEWORK}/extractor.py",
30-
"decorator_config": {
31-
"name": "PLACEHOLDER_NAME",
32-
"custom_extractor_path": "$GRAPH_NET_ROOT/${FRAMEWORK}/naive_graph_decomposer.py",
33-
"custom_extractor_config": {
34-
"output_dir": "PLACEHOLDER_DIR",
35-
"split_positions": [],
36-
"group_head_and_tail": true,
37-
"chain_style": false
38-
}
47+
"test_module_name": "test_target_device",
48+
"test_target_device_arguments": {
49+
"model-path": null,
50+
"reference-dir": null,
51+
"device": "xpu"
3952
}
4053
}
4154
EOF
4255
)
4356

44-
TEST_CONFIG_B64=$(echo "$test_config_json_str" | base64 -w 0)
45-
EXTRACTOR_CONFIG_B64=$(echo "$extractor_config_json_str" | base64 -w 0)
57+
# Select which test-config JSON to base64-encode for the driver script.
test_module_name="test_reference_device"
if [ "${test_module_name}" = "test_compiler" ]; then
    TEST_CONFIG_B64=$(echo "$test_compiler_config_str" | base64 -w 0)
elif [ "${test_module_name}" = "test_reference_device" ]; then
    TEST_CONFIG_B64=$(echo "$test_reference_device_config_str" | base64 -w 0)
elif [ "${test_module_name}" = "test_target_device" ]; then
    # FIX: this branch previously encoded test_reference_device_config_str,
    # so the target-device run was handed the reference-device config.
    TEST_CONFIG_B64=$(echo "$test_target_device_config_str" | base64 -w 0)
else
    echo "test_module_name (${test_module_name}) is unsupported!"
    # FIX: bare `exit` propagates echo's status (0), hiding the error.
    exit 1
fi
4668

4769
echo "Starting GraphNet Auto-Debugger"
4870
echo "--------------------------------------------------------"
@@ -56,7 +78,6 @@ python3 -m graph_net.subgraph_decompose_and_evaluation_step \
5678
--output-dir="$OUTPUT_DIR" \
5779
--framework="${FRAMEWORK}" \
5880
--test-config="$TEST_CONFIG_B64" \
59-
--decorator-config="$EXTRACTOR_CONFIG_B64" \
6081
--tolerance="$TOLERANCE" \
6182
--max-subgraph-size="$INITIAL_MAX_SIZE"
6283

0 commit comments

Comments
 (0)