Skip to content

Commit 413742e

Browse files
committed
Polish the timing codes of test_compiler of paddle.
1 parent eb6948b commit 413742e

File tree

2 files changed

+173
-53
lines changed

2 files changed

+173
-53
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ repos:
33
rev: 23.1.0
44
hooks:
55
- id: black
6-
language_version: python3
6+
language_version: python

graph_net/paddle/test_compiler.py

Lines changed: 172 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
1-
from . import utils
21
import argparse
32
import importlib.util
4-
import inspect
53
import paddle
64
from pathlib import Path
7-
from typing import Type, Any
85
import sys
96
import os
10-
import os.path
117
from dataclasses import dataclass
128
from contextlib import contextmanager
139
import time
1410
import numpy as np
11+
import random
12+
13+
from . import utils
1514

1615

1716
def load_class_from_file(file_path: str, class_name: str):
@@ -33,7 +32,6 @@ def load_class_from_file(file_path: str, class_name: str):
3332

3433

3534
def get_synchronizer_func(args):
36-
assert args.compiler == "default"
3735
return paddle.device.synchronize
3836

3937

@@ -49,37 +47,46 @@ def get_input_dict(args):
4947
params = inputs_params["weight_info"]
5048
inputs = inputs_params["input_info"]
5149

52-
params.update(inputs)
53-
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
54-
return state_dict
55-
56-
57-
@dataclass
58-
class DurationBox:
59-
value: int
50+
param_dtypes = set()
51+
for name, info in params.items():
52+
dtype = str(info["info"]["dtype"])
53+
if dtype not in param_dtypes:
54+
param_dtypes.add(dtype)
6055

56+
input_dtypes = set()
57+
for name, info in inputs.items():
58+
dtype = str(info["info"]["dtype"])
59+
if dtype not in input_dtypes:
60+
input_dtypes.add(dtype)
6161

62-
@contextmanager
63-
def naive_timer(duration_box, get_synchronizer_func):
64-
get_synchronizer_func()
65-
start = time.time()
66-
yield
67-
get_synchronizer_func()
68-
end = time.time()
69-
duration_box.value = end - start
62+
params.update(inputs)
63+
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
64+
return state_dict, list(input_dtypes), list(param_dtypes)
7065

7166

7267
def get_input_spec(args):
7368
inputs_params_list = utils.load_converted_list_from_text(f"{args.model_path}")
7469
input_spec = [None] * len(inputs_params_list)
7570
for i, v in enumerate(inputs_params_list):
76-
name = v["name"]
7771
dtype = v["info"]["dtype"]
7872
shape = v["info"]["shape"]
7973
input_spec[i] = paddle.static.InputSpec(shape, dtype)
8074
return input_spec
8175

8276

77+
def get_compiled_model(args, model):
78+
input_spec = get_input_spec(args)
79+
build_strategy = paddle.static.BuildStrategy()
80+
compiled_model = paddle.jit.to_static(
81+
model,
82+
input_spec=input_spec,
83+
build_strategy=build_strategy,
84+
full_graph=True,
85+
)
86+
compiled_model.eval()
87+
return compiled_model
88+
89+
8390
def regular_item(item):
8491
if isinstance(item, paddle.Tensor) and (item.dtype == paddle.bfloat16):
8592
item = np.array(item.astype("float32"))
@@ -90,37 +97,129 @@ def regular_item(item):
9097
return item
9198

9299

100+
def count_number_of_ops(args, model):
101+
static_model = paddle.jit.to_static(
102+
model,
103+
input_spec=get_input_spec(args),
104+
full_graph=True,
105+
backend=None,
106+
)
107+
static_model.eval()
108+
program = model.forward.concrete_program.main_program
109+
# print(program)
110+
111+
num_ops = 0
112+
for block in program.blocks:
113+
for op in block.ops:
114+
if op.name() != "pd_op.data" and not op.name().startswith("builtin."):
115+
num_ops += 1
116+
print(f"Totally {num_ops} ops.")
117+
print("")
118+
return num_ops
119+
120+
121+
@dataclass
122+
class DurationBox:
123+
value: int
124+
125+
126+
@contextmanager
127+
def naive_timer(duration_box, synchronizer_func):
128+
synchronizer_func()
129+
start = time.time()
130+
yield
131+
synchronizer_func()
132+
end = time.time()
133+
duration_box.value = end - start
134+
135+
136+
def time_execution_with_cuda_event(
137+
model_call, synchronizer_func, num_warmup=3, num_trials=10, profile=False
138+
):
139+
outs = None
140+
141+
# warmups
142+
for _ in range(num_warmup):
143+
outs = model_call()
144+
synchronizer_func()
145+
146+
elapsed_times = []
147+
if profile:
148+
paddle.base.core.nvprof_start()
149+
150+
# actual trials
151+
for trial in range(num_trials):
152+
# create event marker default is not interprocess
153+
start_event = paddle.device.Event(enable_timing=True)
154+
end_event = paddle.device.Event(enable_timing=True)
155+
156+
start_event.record()
157+
outs = model_call()
158+
end_event.record()
159+
synchronizer_func()
160+
161+
# Calculate the elapsed time in milliseconds
162+
elapsed_time_ms = start_event.elapsed_time(end_event)
163+
elapsed_times.append(elapsed_time_ms)
164+
if profile:
165+
paddle.base.core.nvprof_stop()
166+
elapsed_times = elapsed_times[num_trials // 2 :]
167+
return outs, np.mean(elapsed_times)
168+
169+
170+
def time_execution_naive(model_call, synchronizer_func, num_warmup=3, num_trials=10):
171+
outs = None
172+
173+
# warmups
174+
for _ in range(num_warmup):
175+
outs = model_call()
176+
177+
# actual trials
178+
duration_box = DurationBox(-1)
179+
with naive_timer(duration_box, synchronizer_func):
180+
for i in range(num_trials):
181+
outs = model_call()
182+
return outs, duration_box.value * 1000 / float(num_trials)
183+
184+
185+
def measure_performance(model_call, synchronizer_func, args, profile=False):
186+
if not args.use_naive_timer:
187+
outs, times = time_execution_with_cuda_event(
188+
model_call,
189+
synchronizer_func=synchronizer_func,
190+
num_warmup=args.warmup,
191+
num_trials=args.trials,
192+
profile=profile,
193+
)
194+
else:
195+
outs, times = time_execution_naive(
196+
model_call,
197+
synchronizer_func=synchronizer_func,
198+
num_warmup=args.warmup,
199+
num_trials=args.trials,
200+
)
201+
return outs, times
202+
203+
93204
def test_single_model(args):
94205
synchronizer_func = get_synchronizer_func(args)
95-
input_dict = get_input_dict(args)
96-
model_dy = get_model(args)
97-
98-
# eager
99-
print("-- Run with eager mode")
100-
model_dy.eval()
101-
for _ in range(args.warmup if args.warmup > 0 else 0):
102-
model_dy(**input_dict)
103-
eager_duration_box = DurationBox(-1)
104-
with naive_timer(eager_duration_box, synchronizer_func):
105-
expected_out = model_dy(**input_dict)
106-
107-
# compiled
108-
print("-- Run with compiled mode")
109-
input_spec = get_input_spec(args)
110-
build_strategy = paddle.static.BuildStrategy()
111-
# build_strategy.build_cinn_pass = True
112-
compiled_model = paddle.jit.to_static(
113-
model_dy,
114-
input_spec=input_spec,
115-
build_strategy=build_strategy,
116-
full_graph=True,
206+
input_dict, input_dtypes, param_dtypes = get_input_dict(args)
207+
model = get_model(args)
208+
model.eval()
209+
210+
# Collect model information
211+
num_ops = count_number_of_ops(args, model)
212+
213+
print("Run on eager mode")
214+
expected_out, eager_time_ms = measure_performance(
215+
lambda: model(**input_dict), synchronizer_func, args, profile=False
216+
)
217+
218+
print("Run on compiling mode")
219+
compiled_model = get_compiled_model(args, model)
220+
compiled_out, compiled_time_ms = measure_performance(
221+
lambda: compiled_model(**input_dict), synchronizer_func, args, profile=False
117222
)
118-
compiled_model.eval()
119-
for _ in range(args.warmup if args.warmup > 0 else 0):
120-
compiled_model(**input_dict)
121-
compiled_duration_box = DurationBox(-1)
122-
with naive_timer(compiled_duration_box, synchronizer_func):
123-
compiled_out = compiled_model(**input_dict)
124223

125224
if isinstance(expected_out, paddle.Tensor):
126225
expected_out = [expected_out]
@@ -164,7 +263,11 @@ def print_cmp(key, func, **kwargs):
164263
print_cmp("cmp.diff_count_atol2_rtol1", get_cmp_diff_count, atol=1e-2, rtol=1e-1)
165264

166265
print(
167-
f"{args.log_prompt} duration model_path:{args.model_path} eager:{eager_duration_box.value} compiled:{compiled_duration_box.value}",
266+
f"{args.log_prompt} information model_path:{args.model_path} {num_ops} ops, param_dtypes:{param_dtypes}, input_dtypes:{input_dtypes}",
267+
file=sys.stderr,
268+
)
269+
print(
270+
f"{args.log_prompt} duration model_path:{args.model_path} eager:{eager_time_ms:.5f} ms, compiled:{compiled_time_ms:.5f} ms, speedup:{eager_time_ms / compiled_time_ms:.3f}",
168271
file=sys.stderr,
169272
)
170273

@@ -210,6 +313,7 @@ def test_multi_models(args):
210313
f"--model-path {model_path}",
211314
f"--compiler {args.compiler}",
212315
f"--warmup {args.warmup}",
316+
f"--trials {args.trials}",
213317
f"--log-prompt {args.log_prompt}",
214318
]
215319
)
@@ -240,6 +344,13 @@ def is_single_model_dir(model_dir):
240344

241345
def main(args):
242346
assert os.path.isdir(args.model_path)
347+
assert args.compiler == "CINN"
348+
349+
random_seed = 123
350+
paddle.seed(random_seed)
351+
random.seed(random_seed)
352+
np.random.seed(random_seed)
353+
243354
if is_single_model_dir(args.model_path):
244355
test_single_model(args)
245356
else:
@@ -258,12 +369,21 @@ def main(args):
258369
"--compiler",
259370
type=str,
260371
required=False,
261-
default="default",
372+
default="CINN",
262373
help="Path to customized compiler python file",
263374
)
264375
parser.add_argument(
265376
"--warmup", type=int, required=False, default=5, help="Number of warmup steps"
266377
)
378+
parser.add_argument(
379+
"--trials", type=int, required=False, default=10, help="Number of timing trials"
380+
)
381+
parser.add_argument(
382+
"--use-naive-timer",
383+
action="store_true",
384+
default=False,
385+
help="Use naive timer for permance measuring.",
386+
)
267387
parser.add_argument(
268388
"--log-prompt",
269389
type=str,

0 commit comments

Comments
 (0)