@@ -97,12 +97,37 @@ def get_compiled_model(args, model):
9797 return compiled_model
9898
9999
100+ def count_number_of_ops (args , model , eager_mode ):
101+ if eager_mode :
102+ static_model = paddle .jit .to_static (
103+ model ,
104+ input_spec = get_input_spec (args ),
105+ full_graph = True ,
106+ backend = None ,
107+ )
108+ static_model .eval ()
109+ program = static_model .forward .concrete_program .main_program
110+ else :
111+ program = model .forward .concrete_program .main_program
112+ print (program )
113+
114+ num_ops = 0
115+ for block in program .blocks :
116+ for op in block .ops :
117+ if op .name () != "pd_op.data" and not op .name ().startswith ("builtin." ):
118+ num_ops += 1
119+ print (f"Totally { num_ops } ops." )
120+ print ("" )
121+ return num_ops
122+
123+
100124def measure_performance (model_call , args , synchronizer_func ):
101125 stats = {}
102126
103127 # Warmup runs
128+ outs = model_call ()
104129 for _ in range (args .warmup ):
105- outs = model_call ()
130+ model_call ()
106131 synchronizer_func ()
107132
108133 hardware_name = get_hardward_name (args )
@@ -130,7 +155,7 @@ def measure_performance(model_call, args, synchronizer_func):
130155 end_event = paddle .device .Event (enable_timing = True )
131156
132157 start_event .record ()
133- outs = model_call ()
158+ model_call ()
134159 end_event .record ()
135160
136161 gpu_time_ms = start_event .elapsed_time (end_event )
@@ -149,7 +174,7 @@ def measure_performance(model_call, args, synchronizer_func):
149174 for i in range (args .trials ):
150175 duration_box = test_compiler_util .DurationBox (- 1 )
151176 with test_compiler_util .naive_timer (duration_box , synchronizer_func ):
152- outs = model_call ()
177+ model_call ()
153178 print (f"Trial { i + 1 } : e2e={ duration_box .value :.4f} ms" )
154179 e2e_times .append (duration_box .value )
155180 stats ["e2e" ] = test_compiler_util .get_timing_stats (e2e_times )
@@ -205,28 +230,15 @@ def print_cmp(key, func, **kwargs):
205230 expected_out = regular_outputs (expected_out )
206231 compiled_out = regular_outputs (compiled_out )
207232
208- print_cmp ("cmp.equal" , get_cmp_equal )
209- print_cmp ("cmp.all_close_atol8_rtol8" , get_cmp_all_close , atol = 1e-8 , rtol = 1e-8 )
210- print_cmp ("cmp.all_close_atol8_rtol5" , get_cmp_all_close , atol = 1e-8 , rtol = 1e-5 )
211- print_cmp ("cmp.all_close_atol5_rtol5" , get_cmp_all_close , atol = 1e-5 , rtol = 1e-5 )
212- print_cmp ("cmp.all_close_atol3_rtol2" , get_cmp_all_close , atol = 1e-3 , rtol = 1e-2 )
213- print_cmp ("cmp.all_close_atol2_rtol1" , get_cmp_all_close , atol = 1e-2 , rtol = 1e-1 )
214- print_cmp ("cmp.max_diff" , get_cmp_max_diff )
215- print_cmp ("cmp.mean_diff" , get_cmp_mean_diff )
216- print_cmp (
217- "cmp.diff_count_atol8_rtol8" , get_cmp_diff_count , atol = 1e-8 , rtol = 1e-8
218- )
219- print_cmp (
220- "cmp.diff_count_atol8_rtol5" , get_cmp_diff_count , atol = 1e-8 , rtol = 1e-5
221- )
222- print_cmp (
223- "cmp.diff_count_atol5_rtol5" , get_cmp_diff_count , atol = 1e-5 , rtol = 1e-5
224- )
225- print_cmp (
226- "cmp.diff_count_atol3_rtol2" , get_cmp_diff_count , atol = 1e-3 , rtol = 1e-2
227- )
228- print_cmp (
229- "cmp.diff_count_atol2_rtol1" , get_cmp_diff_count , atol = 1e-2 , rtol = 1e-1
233+ test_compiler_util .check_correctness (
234+ args ,
235+ expected_out ,
236+ compiled_out ,
237+ cmp_equal_func = get_cmp_equal ,
238+ cmp_all_close_func = get_cmp_all_close ,
239+ cmp_max_diff_func = get_cmp_max_diff ,
240+ cmp_mean_diff_func = get_cmp_mean_diff ,
241+ cmp_diff_count_func = get_cmp_diff_count ,
230242 )
231243
232244
@@ -236,6 +248,8 @@ def test_single_model(args):
236248 model = get_model (args )
237249 model .eval ()
238250
251+ num_eager_ops = count_number_of_ops (args , model , eager_mode = True )
252+
239253 test_compiler_util .print_basic_config (
240254 args , get_hardward_name (args ), get_compile_framework_version (args )
241255 )
0 commit comments