Skip to content

Commit 6952bc8

Browse files
committed
Define a unified BenchmarkResult and support writing it to JSON.
1 parent 51f61fc commit 6952bc8

File tree

2 files changed

+122
-3
lines changed

2 files changed

+122
-3
lines changed

graph_net/benchmark_result.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import os
2+
import sys
3+
import json
4+
import re
5+
6+
7+
class BenchmarkResult:
    """Collect the outcome of one benchmark run and dump it as JSON.

    The result is kept in four dictionaries:

    * ``configuration`` -- model name, device, hardware, compiler, framework
      version, and the warmup/trial counts taken from ``args``.
    * ``model_info`` -- operator count and input/parameter dtypes.
    * ``correctness`` -- comparison results keyed by comparison name.
    * ``performance`` -- eager time, compiled time and their ratio.
    """

    # Leaf directories named "subgraph" or "subgraph_<N>" are not unique on
    # their own, so they are qualified with the parent directory name.
    _SUBGRAPH_DIR_PATTERN = re.compile(r"subgraph(_\d+)?")

    def __init__(self, args, hardware, compile_framework_version):
        """Initialise the result from the parsed command-line arguments.

        Args:
            args: Object exposing ``model_path``, ``device``, ``compiler``,
                ``warmup`` and ``trials`` attributes.
            hardware: Name of the hardware the benchmark runs on.
            compile_framework_version: Version string of the compiler
                framework.
        """
        self.configuration = {
            "model_name": self.get_model_name(args),
            "device": args.device,
            "hardware": hardware,
            "compiler": args.compiler,
            "compile_framework_version": compile_framework_version,
            "warmup": args.warmup,
            "trials": args.trials,
        }
        # Placeholders until update_model_info() is called.
        self.model_info = {
            "num_ops": -1,
            "input_dtypes": None,
            "param_dtypes": None,
        }
        self.correctness = {}
        # Placeholders until update_performance() is called.
        self.performance = {
            "eager": None,
            "compiled": None,
            "speedup": None,
        }

    def get_model_name(self, args):
        """Derive a model name from the last component(s) of ``args.model_path``.

        The path is normalised first, so a trailing separator does not yield
        an empty name. When the last component is ``subgraph`` or
        ``subgraph_<N>`` and a parent component exists, the name is
        ``<parent>_<last>``; otherwise it is the last component alone.

        Args:
            args: Object exposing a ``model_path`` string attribute.

        Returns:
            The model name, or an empty string if the path has no components.
        """
        normalized = os.path.normpath(args.model_path)
        # Drop empty parts produced by a leading separator of absolute paths.
        fields = [part for part in normalized.split(os.sep) if part]
        if not fields:
            return ""

        leaf = fields[-1]
        if len(fields) > 1 and self._SUBGRAPH_DIR_PATTERN.fullmatch(leaf):
            return f"{fields[-2]}_{leaf}"
        return leaf

    def update_model_info(self, num_ops, input_dtypes, param_dtypes):
        """Store the operator count and the input/parameter dtypes."""
        self.model_info["num_ops"] = num_ops
        self.model_info["input_dtypes"] = input_dtypes
        self.model_info["param_dtypes"] = param_dtypes

    def update_correctness(self, key, cmp_ret):
        """Record the comparison result ``cmp_ret`` under ``key``."""
        self.correctness[key] = cmp_ret

    def update_corrrectness(self, key, cmp_ret):
        """Misspelled alias of ``update_correctness`` kept for existing callers."""
        self.update_correctness(key, cmp_ret)

    def update_performance(self, eager_time_ms, compiled_time_ms):
        """Store both timings and recompute the speedup.

        Args:
            eager_time_ms: Eager-mode time; may be ``None``.
            compiled_time_ms: Compiled-mode time; may be ``None``.

        Returns:
            ``eager_time_ms / compiled_time_ms`` when both timings are
            positive, otherwise ``None``.
        """
        self.performance["eager"] = eager_time_ms
        self.performance["compiled"] = compiled_time_ms

        # Always overwrite the speedup so a value from an earlier call cannot
        # survive next to timings it was not computed from.
        speedup = None
        if (
            eager_time_ms is not None
            and compiled_time_ms is not None
            and eager_time_ms > 0
            and compiled_time_ms > 0
        ):
            speedup = eager_time_ms / compiled_time_ms
        self.performance["speedup"] = speedup
        return speedup

    def write_to_json(self, output_dir):
        """Write the result to ``<output_dir>/<model_name>_<compiler>.json``.

        The directory is created if it does not exist. Float entries of
        ``performance`` are rounded to six decimal places. The saved path is
        reported on stderr and the result dictionary is printed to stdout.

        Args:
            output_dir: Directory that receives the JSON file.

        Raises:
            ValueError: If ``output_dir`` is ``None``.
        """
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if output_dir is None:
            raise ValueError("output_dir must not be None")
        os.makedirs(output_dir, exist_ok=True)

        result_data = {
            "configuration": self.configuration,
            "model_info": self.model_info,
            "correctness": self.correctness,
            "performance": {
                k: float(f"{v:.6f}") if isinstance(v, float) else v
                for k, v in self.performance.items()
            },
        }
        model_name = self.configuration["model_name"]
        compiler_name = self.configuration["compiler"]
        file_path = os.path.join(output_dir, f"{model_name}_{compiler_name}.json")
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(result_data, f, indent=4)
            print(f"Result saved to {file_path}", file=sys.stderr)
        print(result_data)

graph_net/paddle/test_compiler.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
import time
1010
import numpy as np
1111
import random
12+
import platform
1213

13-
from . import utils
14+
from graph_net.paddle import utils
15+
from graph_net.benchmark_result import BenchmarkResult
1416

1517

1618
def load_class_from_file(file_path: str, class_name: str):
@@ -201,6 +203,27 @@ def measure_performance(model_call, synchronizer_func, args, profile=False):
201203
return outs, times
202204

203205

206+
def init_benchmark_result(args):
207+
if args.device == "cuda":
208+
hardware = paddle.device.cuda.get_device_name(0)
209+
elif args.device == "cpu":
210+
hardware = platform.processor()
211+
else:
212+
hardware = "unknown"
213+
214+
if args.compiler == "CINN":
215+
compile_framework_version = paddle.__version__
216+
else:
217+
compile_framework_version = "unknown"
218+
219+
result_data = BenchmarkResult(
220+
args=args,
221+
hardware=hardware,
222+
compile_framework_version=compile_framework_version,
223+
)
224+
return result_data
225+
226+
204227
def test_single_model(args):
205228
synchronizer_func = get_synchronizer_func(args)
206229
input_dict, input_dtypes, param_dtypes = get_input_dict(args)
@@ -210,12 +233,16 @@ def test_single_model(args):
210233
# Collect model information
211234
num_ops = count_number_of_ops(args, model)
212235

213-
print("Run on eager mode")
236+
# Initialize benchmark result
237+
result_data = init_benchmark_result(args)
238+
result_data.update_model_info(num_ops, input_dtypes, param_dtypes)
239+
240+
# Run on eager mode
214241
expected_out, eager_time_ms = measure_performance(
215242
lambda: model(**input_dict), synchronizer_func, args, profile=False
216243
)
217244

218-
print("Run on compiling mode")
245+
# Run on compiling mode
219246
compiled_model = get_compiled_model(args, model)
220247
compiled_out, compiled_time_ms = measure_performance(
221248
lambda: compiled_model(**input_dict), synchronizer_func, args, profile=False
@@ -243,6 +270,7 @@ def test_single_model(args):
243270

244271
def print_cmp(key, func, **kwargs):
245272
cmp_ret = func(expected_out, compiled_out, **kwargs)
273+
result_data.update_corrrectness(key, cmp_ret)
246274
print(
247275
f"{args.log_prompt} {key} model_path:{args.model_path} {cmp_ret}",
248276
file=sys.stderr,
@@ -271,6 +299,10 @@ def print_cmp(key, func, **kwargs):
271299
file=sys.stderr,
272300
)
273301

302+
result_data.update_performance(eager_time_ms, compiled_time_ms)
303+
if args.output_dir:
304+
result_data.write_to_json(args.output_dir)
305+
274306

275307
def get_cmp_equal(expected_out, compiled_out):
276308
return " ".join(
@@ -372,6 +404,13 @@ def main(args):
372404
default="CINN",
373405
help="Path to customized compiler python file",
374406
)
407+
parser.add_argument(
408+
"--device",
409+
type=str,
410+
required=False,
411+
default="cuda",
412+
help="Device for testing the compiler (e.g., 'cpu' or 'cuda')",
413+
)
375414
parser.add_argument(
376415
"--warmup", type=int, required=False, default=5, help="Number of warmup steps"
377416
)
@@ -391,5 +430,12 @@ def main(args):
391430
default="graph-net-test-compiler-log",
392431
help="Log prompt for performance log filtering.",
393432
)
433+
parser.add_argument(
434+
"--output-dir",
435+
type=str,
436+
required=False,
437+
default=None,
438+
help="Directory to save the structured JSON result file.",
439+
)
394440
args = parser.parse_args()
395441
main(args=args)

0 commit comments

Comments
 (0)