Skip to content

Commit b0ba9c6

Browse files
committed
Reorganize code.
1 parent 29ab5ae commit b0ba9c6

File tree

1 file changed

+39
-25
lines changed

1 file changed

+39
-25
lines changed

graph_net/paddle/test_compiler.py

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,37 @@ def get_compiled_model(args, model):
9797
return compiled_model
9898

9999

100+
def count_number_of_ops(args, model, eager_mode):
    """Count the operators in the static program behind ``model``.

    Args:
        args: Run configuration; only forwarded to ``get_input_spec`` when
            ``eager_mode`` is true.
        model: When ``eager_mode`` is true, a model that is first converted
            with ``paddle.jit.to_static``. Otherwise an object that already
            exposes ``forward.concrete_program.main_program``.
        eager_mode: Whether ``model`` still has to be converted to a static
            graph before its program can be inspected.

    Returns:
        The number of ops over all blocks of the program, excluding ops
        named ``"pd_op.data"`` and ops whose name starts with ``"builtin."``.

    Side effects:
        Prints the program, the op total and a trailing blank line.
    """
    if eager_mode:
        # NOTE(review): paddle.jit.to_static may decorate ``model.forward``
        # in place rather than working on a copy -- confirm that the caller's
        # eager model is still eager after this call.
        static_model = paddle.jit.to_static(
            model,
            input_spec=get_input_spec(args),
            full_graph=True,
            backend=None,
        )
        static_model.eval()
        program = static_model.forward.concrete_program.main_program
    else:
        program = model.forward.concrete_program.main_program
    print(program)

    def _is_counted(op_name):
        # Data feeds and builtin.* ops are excluded from the total.
        return op_name != "pd_op.data" and not op_name.startswith("builtin.")

    # Evaluate op.name() once per op and let sum() do the counting.
    num_ops = sum(
        1
        for block in program.blocks
        for op in block.ops
        if _is_counted(op.name())
    )
    print(f"Totally {num_ops} ops.")
    print("")
    return num_ops
122+
123+
100124
def measure_performance(model_call, args, synchronizer_func):
101125
stats = {}
102126

103127
# Warmup runs
128+
outs = model_call()
104129
for _ in range(args.warmup):
105-
outs = model_call()
130+
model_call()
106131
synchronizer_func()
107132

108133
hardware_name = get_hardward_name(args)
@@ -130,7 +155,7 @@ def measure_performance(model_call, args, synchronizer_func):
130155
end_event = paddle.device.Event(enable_timing=True)
131156

132157
start_event.record()
133-
outs = model_call()
158+
model_call()
134159
end_event.record()
135160

136161
gpu_time_ms = start_event.elapsed_time(end_event)
@@ -149,7 +174,7 @@ def measure_performance(model_call, args, synchronizer_func):
149174
for i in range(args.trials):
150175
duration_box = test_compiler_util.DurationBox(-1)
151176
with test_compiler_util.naive_timer(duration_box, synchronizer_func):
152-
outs = model_call()
177+
model_call()
153178
print(f"Trial {i + 1}: e2e={duration_box.value:.4f} ms")
154179
e2e_times.append(duration_box.value)
155180
stats["e2e"] = test_compiler_util.get_timing_stats(e2e_times)
@@ -205,28 +230,15 @@ def print_cmp(key, func, **kwargs):
205230
expected_out = regular_outputs(expected_out)
206231
compiled_out = regular_outputs(compiled_out)
207232

208-
print_cmp("cmp.equal", get_cmp_equal)
209-
print_cmp("cmp.all_close_atol8_rtol8", get_cmp_all_close, atol=1e-8, rtol=1e-8)
210-
print_cmp("cmp.all_close_atol8_rtol5", get_cmp_all_close, atol=1e-8, rtol=1e-5)
211-
print_cmp("cmp.all_close_atol5_rtol5", get_cmp_all_close, atol=1e-5, rtol=1e-5)
212-
print_cmp("cmp.all_close_atol3_rtol2", get_cmp_all_close, atol=1e-3, rtol=1e-2)
213-
print_cmp("cmp.all_close_atol2_rtol1", get_cmp_all_close, atol=1e-2, rtol=1e-1)
214-
print_cmp("cmp.max_diff", get_cmp_max_diff)
215-
print_cmp("cmp.mean_diff", get_cmp_mean_diff)
216-
print_cmp(
217-
"cmp.diff_count_atol8_rtol8", get_cmp_diff_count, atol=1e-8, rtol=1e-8
218-
)
219-
print_cmp(
220-
"cmp.diff_count_atol8_rtol5", get_cmp_diff_count, atol=1e-8, rtol=1e-5
221-
)
222-
print_cmp(
223-
"cmp.diff_count_atol5_rtol5", get_cmp_diff_count, atol=1e-5, rtol=1e-5
224-
)
225-
print_cmp(
226-
"cmp.diff_count_atol3_rtol2", get_cmp_diff_count, atol=1e-3, rtol=1e-2
227-
)
228-
print_cmp(
229-
"cmp.diff_count_atol2_rtol1", get_cmp_diff_count, atol=1e-2, rtol=1e-1
233+
test_compiler_util.check_correctness(
234+
args,
235+
expected_out,
236+
compiled_out,
237+
cmp_equal_func=get_cmp_equal,
238+
cmp_all_close_func=get_cmp_all_close,
239+
cmp_max_diff_func=get_cmp_max_diff,
240+
cmp_mean_diff_func=get_cmp_mean_diff,
241+
cmp_diff_count_func=get_cmp_diff_count,
230242
)
231243

232244

@@ -236,6 +248,8 @@ def test_single_model(args):
236248
model = get_model(args)
237249
model.eval()
238250

251+
num_eager_ops = count_number_of_ops(args, model, eager_mode=True)
252+
239253
test_compiler_util.print_basic_config(
240254
args, get_hardward_name(args), get_compile_framework_version(args)
241255
)

0 commit comments

Comments
 (0)