Skip to content

Commit 7f7366a

Browse files
committed
Run test_target_device successfully.
1 parent 19cefd5 commit 7f7366a

File tree

2 files changed

+31
-24
lines changed

2 files changed

+31
-24
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,32 @@ def convert_b64_string_to_json(b64str):
1919
class TaskController:
2020
def __init__(self, args):
2121
self.root_output_dir = os.path.abspath(args.output_dir)
22-
self.pass_id = self._determine_current_pass_id(self.root_output_dir)
2322
self.test_config = convert_b64_string_to_json(args.test_config)
2423
assert "test_module_name" in self.test_config
2524

26-
self._init_task_scheduler(self.test_config["test_module_name"])
25+
test_module_name = self.test_config["test_module_name"]
26+
max_pass_id = self._determine_max_pass_id(self.root_output_dir)
27+
self.current_pass_id = (
28+
max_pass_id if test_module_name == "test_target_device" else max_pass_id + 1
29+
)
30+
print(
31+
f"test_module_name: {test_module_name}, current_pass_id: {self.current_pass_id}"
32+
)
2733

28-
def _determine_current_pass_id(self, output_dir: str) -> int:
34+
self._init_task_scheduler(test_module_name)
35+
36+
def _determine_max_pass_id(self, output_dir: str) -> int:
2937
"""Scans the output directory to determine the next pass ID."""
3038
if not os.path.exists(output_dir):
31-
return 0
39+
return -1
3240
existing_passes = glob.glob(os.path.join(output_dir, "pass_*"))
3341
valid_ids = []
3442
for p in existing_passes:
3543
basename = os.path.basename(p)
3644
parts = basename.split("_")
3745
if len(parts) == 2 and parts[1].isdigit():
3846
valid_ids.append(int(parts[1]))
39-
return max(valid_ids) + 1 if valid_ids else 0
47+
return max(valid_ids) if valid_ids else -1
4048

4149
def _init_task_scheduler(self, test_module_name):
4250
assert test_module_name in [
@@ -177,7 +185,7 @@ def run_evaluation(
177185
test_module_name = test_config["test_module_name"]
178186
test_module_arguments = test_config[f"{test_module_name}_arguments"]
179187
test_module_arguments["model-path"] = work_dir
180-
if test_module_name == "test_reference_device":
188+
if test_module_name in ["test_reference_device", "test_target_device"]:
181189
test_module_arguments["reference-dir"] = os.path.join(
182190
work_dir, "reference_device_outputs"
183191
)
@@ -204,7 +212,7 @@ def run_evaluation(
204212
def main(args):
205213
task_controller = TaskController(args)
206214
base_output_dir = task_controller.root_output_dir
207-
current_pass_id = task_controller.pass_id
215+
current_pass_id = task_controller.current_pass_id
208216

209217
print("=" * 80)
210218
print(f" GraphNet Auto-Debugger | ROUND: PASS_{current_pass_id}")
@@ -236,9 +244,8 @@ def main(args):
236244

237245
# --- Step 2: Prepare Workspace ---
238246
pass_work_dir = os.path.join(base_output_dir, f"pass_{current_pass_id}")
239-
if os.path.exists(pass_work_dir):
240-
shutil.rmtree(pass_work_dir)
241-
os.makedirs(pass_work_dir, exist_ok=True)
247+
if not os.path.exists(pass_work_dir):
248+
os.makedirs(pass_work_dir, exist_ok=True)
242249

243250
# --- Step 3: Decomposition ---
244251
need_decompose = (
@@ -248,28 +255,29 @@ def main(args):
248255
)
249256
if need_decompose:
250257
print("\n--- Phase 1: Decomposition ---", flush=True)
258+
failed_extraction = []
251259
while need_decompose:
252-
failed_extraction = []
260+
decomposed_samples_dir = os.path.join(
261+
pass_work_dir, "samples" if args.framework == "torch" else "paddle_samples"
262+
)
263+
os.makedirs(decomposed_samples_dir, exist_ok=True)
253264

254265
for idx, model_path in enumerate(target_models):
255266
rectied_model_path = get_rectfied_model_path(model_path)
256267
assert os.path.exists(
257268
rectied_model_path
258269
), f"{rectied_model_path} does not exist."
259270

260-
model_name = get_model_name_with_subgraph_tag(rectied_model_path)
261-
model_out_dir = os.path.join(pass_work_dir, model_name)
262-
os.makedirs(model_out_dir, exist_ok=True)
263-
271+
os.makedirs(decomposed_samples_dir, exist_ok=True)
264272
success = run_decomposer(
265273
args.framework,
266274
rectied_model_path,
267-
model_out_dir,
275+
decomposed_samples_dir,
268276
current_max_size,
269277
)
270278
if not success:
271279
failed_extraction.append(rectied_model_path)
272-
num_decomposed_samples = count_samples(pass_work_dir)
280+
num_decomposed_samples = count_samples(decomposed_samples_dir)
273281
print(
274282
f"- number of graphs: {len(target_models)} -> {num_decomposed_samples}",
275283
flush=True,
@@ -279,8 +287,8 @@ def main(args):
279287

280288
if num_decomposed_samples == len(target_models):
281289
need_decompose = True
282-
shutil.rmtree(pass_work_dir)
283-
os.makedirs(pass_work_dir, exist_ok=True)
290+
shutil.rmtree(decomposed_samples_dir)
291+
os.makedirs(decomposed_samples_dir, exist_ok=True)
284292
current_max_size = max(1, current_max_size // 2)
285293
else:
286294
need_decompose = False
@@ -331,4 +339,5 @@ def main(args):
331339
)
332340
parser.add_argument("--max-subgraph-size", type=int, default=4096)
333341
args = parser.parse_args()
342+
print(args)
334343
main(args)

graph_net/test/subgraph_decompose_and_evaluation_step_test.sh

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/bin/bash
22

3-
export PYTHONPATH=/work/GraphNet:/work/abstract_pass/Athena:$PYTHONPATH
3+
export PYTHONPATH=/root/paddlejob/workspace/env_run/liuyiqun/GraphNet:$PYTHONPATH
44

55
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(graph_net.__file__))")
66

@@ -11,8 +11,6 @@ OUTPUT_DIR="outputs"
1111
TOLERANCE=0
1212
INITIAL_MAX_SIZE=4096
1313

14-
#rm -rf ${OUTPUT_DIR}
15-
1614
test_compiler_config_str=$(cat <<EOF
1715
{
1816
"test_module_name": "test_compiler",
@@ -54,13 +52,13 @@ test_target_device_config_str=$(cat <<EOF
5452
EOF
5553
)
5654

57-
test_module_name="test_reference_device"
55+
test_module_name="test_target_device"
5856
if [ "${test_module_name}" = "test_compiler" ]; then
5957
TEST_CONFIG_B64=$(echo "$test_compiler_config_str" | base64 -w 0)
6058
elif [ "${test_module_name}" = "test_reference_device" ]; then
6159
TEST_CONFIG_B64=$(echo "$test_reference_device_config_str" | base64 -w 0)
6260
elif [ "${test_module_name}" = "test_target_device" ]; then
63-
TEST_CONFIG_B64=$(echo "$test_reference_device_config_str" | base64 -w 0)
61+
TEST_CONFIG_B64=$(echo "$test_target_device_config_str" | base64 -w 0)
6462
else
6563
echo "test_module_name (${test_module_name}) is unsupported!"
6664
exit

0 commit comments

Comments
 (0)