Commit d347d7c: Add test_device support for torch
Parent: 918cc50

File tree: 3 files changed, +415 −2 lines

graph_net/torch/test_compiler.py
Lines changed: 3 additions & 2 deletions

@@ -120,6 +120,7 @@ def get_input_dict(args):
 
 def measure_performance(model_call, args, compiler):
     stats = {}
+    outs = model_call()
 
     # Warmup runs
     for _ in range(args.warmup):
@@ -180,9 +181,9 @@ def measure_performance(model_call, args, compiler):
             flush=True,
         )
         e2e_times.append(duration_box.value)
-    stats["e2e"] = test_compiler_utilget_timing_stats(e2e_times)
+    stats["e2e"] = test_compiler_util.get_timing_stats(e2e_times)
 
-    return stats
+    return outs, stats
 
 
 def test_single_model(args):
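This hunk is the enabler for the new script: measure_performance now invokes the model once before the warmup loop and returns those eager outputs together with the timing stats (it also fixes a missing dot in the test_compiler_util call). A minimal sketch of the new call contract, assuming a zero-argument model_call closure; run_and_dump is a hypothetical wrapper, not part of this commit:

    import torch
    from graph_net.torch import test_compiler

    def run_and_dump(model_call, args, compiler, dump_path):
        # measure_performance now returns (outs, stats): the eager outputs
        # captured once up front, plus the timing-stats dict (e.g. stats["e2e"]).
        outs, stats = test_compiler.measure_performance(model_call, args, compiler)
        torch.save(outs, dump_path)  # persist outputs as a device/compiler reference
        return stats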
Lines changed: 200 additions & 0 deletions

@@ -0,0 +1,200 @@
+import argparse
+import importlib.util
+import torch
+import time
+import numpy as np
+import random
+import os
+from pathlib import Path
+from contextlib import redirect_stdout, redirect_stderr
+import json
+import re
+import sys
+import traceback
+
+from graph_net import path_utils
+from graph_net import test_compiler_util
+from graph_net.torch import test_compiler
+
+
+def get_reference_log_path(reference_dir, model_path):
+    model_name = model_path.split("torch_samples/")[-1].replace(os.sep, "_")
+    return os.path.join(reference_dir, f"{model_name}.log")
+
+
+def get_reference_output_path(reference_dir, model_path):
+    model_name = model_path.split("torch_samples/")[-1].replace(os.sep, "_")
+    return os.path.join(reference_dir, f"{model_name}.pth")
+
+
+def test_single_model(args):
+    ref_log = get_reference_log_path(args.reference_dir, args.model_path)
+    ref_dump = get_reference_output_path(args.reference_dir, args.model_path)
+    print(f"Reference log path: {ref_log}", file=sys.stderr, flush=True)
+    print(f"Reference outputs path: {ref_dump}", file=sys.stderr, flush=True)
+
+    with open(ref_log, "w", encoding="utf-8") as log_f:
+        with redirect_stdout(log_f), redirect_stderr(log_f):
+            compiler = test_compiler.get_compiler_backend(args)
+
+            input_dict = test_compiler.get_input_dict(args)
+            model = test_compiler.get_model(args)
+            model.eval()
+
+            test_compiler_util.print_with_log_prompt(
+                "[Config] seed:", args.seed, args.log_prompt
+            )
+
+            test_compiler_util.print_basic_config(
+                args,
+                test_compiler.get_hardward_name(args),
+                test_compiler.get_compile_framework_version(args),
+            )
+
+            success = False
+            time_stats = {}
+            try:
+                compiled_model = compiler(model)
+                model_call = lambda: compiled_model(**input_dict)
+                outputs, time_stats = test_compiler.measure_performance(
+                    model_call, args, compiler
+                )
+                success = True
+            except Exception as e:
+                print(
+                    f"Run model failed: {str(e)}\n{traceback.format_exc()}",
+                    file=sys.stderr,
+                    flush=True,
+                )
+
+            test_compiler_util.print_running_status(args, success)
+            if success:
+                torch.save(outputs, str(ref_dump))
+                test_compiler_util.print_with_log_prompt(
+                    "[Performance][eager]:", json.dumps(time_stats), args.log_prompt
+                )
+
+    with open(ref_log, "r", encoding="utf-8") as f:
+        content = f.read()
+        print(content, file=sys.stderr, flush=True)
+
+
+def test_multi_models(args):
+    test_samples = test_compiler_util.get_allow_samples(args.allow_list)
+
+    sample_idx = 0
+    failed_samples = []
+    module_name = os.path.splitext(os.path.basename(__file__))[0]
+    for model_path in path_utils.get_recursively_model_path(args.model_path):
+        if test_samples is None or os.path.abspath(model_path) in test_samples:
+            print(
+                f"[{sample_idx}] {module_name}, model_path: {model_path}",
+                file=sys.stderr,
+                flush=True,
+            )
+            cmd = " ".join(
+                [
+                    sys.executable,
+                    f"-m graph_net.torch.{module_name}",
+                    f"--model-path {model_path}",
+                    f"--compiler {args.compiler}",
+                    f"--device {args.device}",
+                    f"--warmup {args.warmup}",
+                    f"--trials {args.trials}",
+                    f"--log-prompt {args.log_prompt}",
+                    f"--seed {args.seed}",
+                    f"--reference-dir {args.reference_dir}",
+                ]
+            )
+            cmd_ret = os.system(cmd)
+            # assert cmd_ret == 0, f"{cmd_ret=}, {cmd=}"
+            if cmd_ret != 0:
+                failed_samples.append(model_path)
+            sample_idx += 1
+
+    print(
+        f"Totally {sample_idx} verified samples, failed {len(failed_samples)} samples.",
+        file=sys.stderr,
+        flush=True,
+    )
+    for model_path in failed_samples:
+        print(f"- {model_path}", file=sys.stderr, flush=True)
+
+
+def main(args):
+    assert os.path.isdir(args.model_path)
+    # Support all torch compilers
+    valid_compilers = list(test_compiler.registry_backend.keys())
+    assert (
+        args.compiler in valid_compilers
+    ), f"Compiler must be one of {valid_compilers}"
+    assert args.device in ["cuda", "cpu", "xpu"]
+
+    test_compiler.set_seed(random_seed=args.seed)
+
+    ref_dump_dir = Path(args.reference_dir)
+    ref_dump_dir.mkdir(parents=True, exist_ok=True)
+
+    if path_utils.is_single_model_dir(args.model_path):
+        test_single_model(args)
+    else:
+        test_multi_models(args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Test compiler performance.")
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        required=True,
+        help="Path to model file(s), each subdirectory containing graph_net.json will be regarded as a model",
+    )
+    parser.add_argument(
+        "--compiler",
+        type=str,
+        required=False,
+        default="inductor",
+        help="Compiler backend to use",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        required=False,
+        default="cuda",
+        help="Device for testing the compiler (e.g., 'cpu' or 'cuda')",
+    )
+    parser.add_argument(
+        "--warmup", type=int, required=False, default=5, help="Number of warmup steps"
+    )
+    parser.add_argument(
+        "--trials", type=int, required=False, default=10, help="Number of timing trials"
+    )
+    parser.add_argument(
+        "--log-prompt",
+        type=str,
+        required=False,
+        default="graph-net-test-device-log",
+        help="Log prompt for performance log filtering.",
+    )
+    parser.add_argument(
+        "--allow-list",
+        type=str,
+        required=False,
+        default=None,
+        help="Path to samples list, each line contains a sample path",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        required=False,
+        default=123,
+        help="Random seed (default: 123)",
+    )
+    parser.add_argument(
+        "--reference-dir",
+        type=str,
+        required=True,
+        help="Directory to save reference stats log and outputs",
+    )
+    args = parser.parse_args()
+    main(args=args)
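The commit title and the default log prompt suggest this new script is graph_net.torch.test_device (the file path itself is not shown in this diff). Invoked as, e.g., python -m graph_net.torch.test_device --model-path <samples_dir> --compiler inductor --device cuda --reference-dir refs, it writes one <model>.log and <model>.pth pair per sample. Those .pth dumps are what make device verification possible: a later run on another device or compiler can be checked tensor-by-tensor against the reference. A minimal comparison sketch, assuming outputs are a tensor or a tuple/list of tensors; paths and tolerances are illustrative:

    import torch

    # Load reference outputs saved by test_single_model via torch.save
    # (file layout from get_reference_output_path); paths are illustrative.
    ref = torch.load("refs_cuda/some_model.pth", map_location="cpu")
    new = torch.load("refs_xpu/some_model.pth", map_location="cpu")

    # Normalize to lists so single-tensor and tuple outputs compare uniformly.
    ref_list = list(ref) if isinstance(ref, (list, tuple)) else [ref]
    new_list = list(new) if isinstance(new, (list, tuple)) else [new]

    for i, (r, n) in enumerate(zip(ref_list, new_list)):
        # Tolerances are illustrative; tighten or loosen per dtype/hardware.
        ok = torch.allclose(r.float(), n.float(), rtol=1e-3, atol=1e-4)
        print(f"output[{i}]: allclose={ok}")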
