Skip to content

Commit 7c7995c

Browse files
committed
feat: add structured json output in test compiler
1 parent 40b5bcb commit 7c7995c

File tree

1 file changed

+107
-36
lines changed

1 file changed

+107
-36
lines changed

graph_net/torch/test_compiler.py

Lines changed: 107 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from contextlib import contextmanager
1414
import time
1515
import json
16+
import numpy as np
1617

1718
"""
1819
Acknowledgement: We introduce evaluation method in https://github.com/ScalingIntelligence/KernelBench to enhance function.
@@ -53,7 +54,6 @@ def synchronize(self):
5354
registry_backend = {
5455
"inductor": InductorBackend(),
5556
"tensorrt": TensorRTBackend(),
56-
"default": InductorBackend(),
5757
}
5858

5959

@@ -115,7 +115,7 @@ def naive_timer(duration_box, synchronizer_func):
115115

116116

117117
def time_execution_naive(
118-
model_call, synchronizer_func, num_warmup: int = 3, num_trials: int = 100
118+
model_call, synchronizer_func, num_warmup: int = 3, num_trials: int = 10
119119
):
120120
print(f"[Profiling] Using device: CPU, warm up {num_warmup}, trials {num_trials}")
121121
for _ in range(num_warmup):
@@ -137,7 +137,6 @@ def get_timing_stats_cpu(elapsed_times: list[float]):
137137
"std": float(f"{np.std(elapsed_times):.3g}"),
138138
"min": float(f"{np.min(elapsed_times):.3g}"),
139139
"max": float(f"{np.max(elapsed_times):.3g}"),
140-
"num_trials": len(elapsed_times),
141140
}
142141
return stats
143142

@@ -148,8 +147,24 @@ def test_single_model(args):
148147
model = get_model(args)
149148
compiled_model = compiler(model)
150149

151-
eager_time_ms = -1
152-
compiled_time_ms = -1
150+
eager_stats = {}
151+
compiled_stats = {}
152+
153+
result_data = {
154+
"configuration": {
155+
"model": os.path.basename(os.path.normpath(args.model_path)),
156+
"compiler": args.compiler,
157+
"device": args.device,
158+
"warmup": args.warmup,
159+
"trials": args.trials,
160+
},
161+
"correctness": {},
162+
"performance": {
163+
"eager": {},
164+
"compiled": {},
165+
"speedup": {},
166+
},
167+
}
153168

154169
eager_model_call = lambda: model(**input_dict)
155170
compiled_model_call = lambda: compiled_model(**input_dict)
@@ -166,7 +181,6 @@ def test_single_model(args):
166181
device=torch.device("cuda:0"),
167182
)
168183
eager_stats = get_timing_stats(eager_times)
169-
eager_time_ms = eager_stats["mean"]
170184

171185
compiled_times = time_execution_with_cuda_event(
172186
compiled_model_call,
@@ -175,7 +189,6 @@ def test_single_model(args):
175189
device=torch.device("cuda:0"),
176190
)
177191
compiled_stats = get_timing_stats(compiled_times)
178-
compiled_time_ms = compiled_stats["mean"]
179192
else:
180193
eager_times = time_execution_naive(
181194
eager_model_call,
@@ -184,7 +197,6 @@ def test_single_model(args):
184197
num_trials=args.trials,
185198
)
186199
eager_stats = get_timing_stats_cpu(eager_times)
187-
eager_time_ms = eager_stats["mean"]
188200

189201
compiled_times = time_execution_naive(
190202
compiled_model_call,
@@ -193,37 +205,80 @@ def test_single_model(args):
193205
num_trials=args.trials,
194206
)
195207
compiled_stats = get_timing_stats_cpu(compiled_times)
196-
compiled_time_ms = compiled_stats["mean"]
197208

198209
expected_out = eager_model_call()
199210
compiled_out = compiled_model_call()
200211

201-
def print_cmp(key, func, **kwargs):
212+
def print_and_store_cmp(key, func, **kwargs):
202213
cmp_ret = func(expected_out, compiled_out, **kwargs)
214+
result_data["correctness"][key] = cmp_ret
203215
print(
204216
f"{args.log_prompt} {key} model_path:{args.model_path} {cmp_ret}",
205217
file=sys.stderr,
206218
)
207219

208-
print_cmp("cmp.equal", get_cmp_equal)
209-
print_cmp("cmp.all_close_atol8_rtol8", get_cmp_all_close, atol=1e-8, rtol=1e-8)
210-
print_cmp("cmp.all_close_atol8_rtol5", get_cmp_all_close, atol=1e-8, rtol=1e-5)
211-
print_cmp("cmp.all_close_atol5_rtol5", get_cmp_all_close, atol=1e-5, rtol=1e-5)
212-
print_cmp("cmp.all_close_atol3_rtol2", get_cmp_all_close, atol=1e-3, rtol=1e-2)
213-
print_cmp("cmp.all_close_atol2_rtol1", get_cmp_all_close, atol=1e-2, rtol=1e-1)
214-
print_cmp("cmp.max_diff", get_cmp_max_diff)
215-
print_cmp("cmp.mean_diff", get_cmp_mean_diff)
216-
print_cmp("cmp.diff_count_atol8_rtol8", get_cmp_diff_count, atol=1e-8, rtol=1e-8)
217-
print_cmp("cmp.diff_count_atol8_rtol5", get_cmp_diff_count, atol=1e-8, rtol=1e-5)
218-
print_cmp("cmp.diff_count_atol5_rtol5", get_cmp_diff_count, atol=1e-5, rtol=1e-5)
219-
print_cmp("cmp.diff_count_atol3_rtol2", get_cmp_diff_count, atol=1e-3, rtol=1e-2)
220-
print_cmp("cmp.diff_count_atol2_rtol1", get_cmp_diff_count, atol=1e-2, rtol=1e-1)
220+
print_and_store_cmp("equal", get_cmp_equal)
221+
print_and_store_cmp(
222+
"all_close_atol8_rtol8", get_cmp_all_close, atol=1e-8, rtol=1e-8
223+
)
224+
print_and_store_cmp(
225+
"all_close_atol8_rtol5", get_cmp_all_close, atol=1e-8, rtol=1e-5
226+
)
227+
print_and_store_cmp(
228+
"all_close_atol5_rtol5", get_cmp_all_close, atol=1e-5, rtol=1e-5
229+
)
230+
print_and_store_cmp(
231+
"all_close_atol3_rtol2", get_cmp_all_close, atol=1e-3, rtol=1e-2
232+
)
233+
print_and_store_cmp(
234+
"all_close_atol2_rtol1", get_cmp_all_close, atol=1e-2, rtol=1e-1
235+
)
236+
print_and_store_cmp("max_diff", get_cmp_max_diff)
237+
print_and_store_cmp("mean_diff", get_cmp_mean_diff)
238+
print_and_store_cmp(
239+
"diff_count_atol8_rtol8", get_cmp_diff_count, atol=1e-8, rtol=1e-8
240+
)
241+
print_and_store_cmp(
242+
"diff_count_atol8_rtol5", get_cmp_diff_count, atol=1e-8, rtol=1e-5
243+
)
244+
print_and_store_cmp(
245+
"diff_count_atol5_rtol5", get_cmp_diff_count, atol=1e-5, rtol=1e-5
246+
)
247+
print_and_store_cmp(
248+
"diff_count_atol3_rtol2", get_cmp_diff_count, atol=1e-3, rtol=1e-2
249+
)
250+
print_and_store_cmp(
251+
"diff_count_atol2_rtol1", get_cmp_diff_count, atol=1e-2, rtol=1e-1
252+
)
253+
254+
eager_time_ms = eager_stats["mean"]
255+
compiled_time_ms = compiled_stats["mean"]
256+
257+
result_data["performance"]["eager"]["mean"] = eager_stats["mean"]
258+
result_data["performance"]["eager"]["std"] = eager_stats["std"]
259+
result_data["performance"]["eager"]["min"] = eager_stats["min"]
260+
result_data["performance"]["eager"]["max"] = eager_stats["max"]
261+
result_data["performance"]["compiled"]["mean"] = compiled_stats["mean"]
262+
result_data["performance"]["compiled"]["std"] = compiled_stats["std"]
263+
result_data["performance"]["compiled"]["min"] = compiled_stats["min"]
264+
result_data["performance"]["compiled"]["max"] = compiled_stats["max"]
265+
if eager_time_ms > 0 and compiled_time_ms > 0:
266+
result_data["performance"]["speedup"] = eager_time_ms / compiled_time_ms
221267

222268
print(
223269
f"{args.log_prompt} duration model_path:{args.model_path} eager:{eager_time_ms:.4f} compiled:{compiled_time_ms:.4f}",
224270
file=sys.stderr,
225271
)
226272

273+
if args.output_dir:
274+
os.makedirs(args.output_dir, exist_ok=True)
275+
model_name = result_data["configuration"]["model"]
276+
compiler_name = args.compiler
277+
file_path = os.path.join(args.output_dir, f"{model_name}_{compiler_name}.json")
278+
with open(file_path, "w") as f:
279+
json.dump(result_data, f, indent=4)
280+
print(f"Result saved to {file_path}", file=sys.stderr)
281+
227282

228283
def get_cmp_equal(expected_out, compiled_out):
229284
return " ".join(
@@ -261,18 +316,27 @@ def get_cmp_diff_count(expected_out, compiled_out, atol, rtol):
261316

262317
def test_multi_models(args):
263318
for model_path in get_recursively_model_path(args.model_path):
264-
cmd = "".join(
265-
[
266-
sys.executable,
267-
" -m graph_net.torch.test_compiler",
268-
f" --model-path {model_path}",
269-
f" --compiler {args.compiler}",
270-
f" --warmup {args.warmup}",
271-
f" --trials {args.trials}",
272-
f" --log-prompt {args.log_prompt}",
273-
f" --device {args.device}",
274-
]
275-
)
319+
cmd_list = [
320+
sys.executable,
321+
"-m",
322+
"graph_net.torch.test_compiler",
323+
"--model-path",
324+
model_path,
325+
"--compiler",
326+
args.compiler,
327+
"--warmup",
328+
str(args.warmup),
329+
"--trials",
330+
str(args.trials),
331+
"--log-prompt",
332+
args.log_prompt,
333+
"--device",
334+
args.device,
335+
]
336+
if args.output_dir:
337+
cmd_list.extend(["--output-dir", args.output_dir])
338+
339+
cmd = " ".join(cmd_list)
276340
cmd_ret = os.system(cmd)
277341
assert cmd_ret == 0, f"{cmd_ret=}, {cmd=}"
278342

@@ -318,7 +382,7 @@ def main(args):
318382
"--compiler",
319383
type=str,
320384
required=False,
321-
default="default",
385+
default="inductor",
322386
help="Path to customized compiler python file",
323387
)
324388
parser.add_argument(
@@ -341,5 +405,12 @@ def main(args):
341405
default="graph-net-test-compiler-log",
342406
help="Log prompt for performance log filtering.",
343407
)
408+
parser.add_argument(
409+
"--output-dir",
410+
type=str,
411+
required=False,
412+
default=None,
413+
help="Directory to save the structured JSON result file.",
414+
)
344415
args = parser.parse_args()
345416
main(args=args)

0 commit comments

Comments
 (0)