1515from graph_net .paddle import utils
1616from graph_net import path_utils
1717from graph_net import test_compiler_util
18- from graph_net .benchmark_result import BenchmarkResult
18+
19+
def get_hardward_name(args):
    """Return a human-readable name for the hardware selected by ``args.device``.

    Args:
        args: Namespace-like object with a ``device`` attribute
            (expected values: ``"cuda"``, ``"cpu"``, or anything else).

    Returns:
        str: The GPU device name for ``"cuda"``, the CPU identifier reported
        by :func:`platform.processor` for ``"cpu"``, otherwise ``"unknown"``.
    """
    # NOTE(review): the original spelling "hardward" is kept — callers in this
    # file (measure_performance, test_single_model) use this exact name.
    device = args.device
    if device == "cuda":
        return paddle.device.cuda.get_device_name(0)
    if device == "cpu":
        return platform.processor()
    return "unknown"
28+
29+
def get_compile_framework_version(args):
    """Report the version of the framework backing the selected compiler.

    Args:
        args: Namespace-like object with a ``compiler`` attribute.

    Returns:
        str: ``paddle.__version__`` when ``args.compiler == "cinn"``,
        otherwise the literal string ``"unknown"``.
    """
    return paddle.__version__ if args.compiler == "cinn" else "unknown"
1936
2037
2138def load_class_from_file (file_path : str , class_name : str ):
@@ -88,15 +105,18 @@ def measure_performance(model_call, args, synchronizer_func):
88105 outs = model_call ()
89106 synchronizer_func ()
90107
108+ hardware_name = get_hardward_name (args )
109+ print (
110+ f"[Profiling] Using device: { args .device } { hardware_name } , warm up { args .warmup } , trials { args .trials } " ,
111+ file = sys .stderr ,
112+ flush = True ,
113+ )
114+
91115 if "cuda" in args .device :
92116 """
93117 Acknowledgement: We evaluate the performance on both end-to-end and GPU-only timings,
94118 With reference to methods only based on CUDA events from KernelBench in https://github.com/ScalingIntelligence/KernelBench
95119 """
96- hardware_name = paddle .device .cuda .get_device_name (0 )
97- print (
98- f"{ args .log_prompt } [Profiling] Using device: { args .device } { hardware_name } , warm up { args .warmup } , trials { args .trials } "
99- )
100120
101121 e2e_times = []
102122 gpu_times = []
@@ -117,17 +137,14 @@ def measure_performance(model_call, args, synchronizer_func):
117137 e2e_times .append (duration_box .value )
118138 gpu_times .append (gpu_time_ms )
119139 print (
120- f"Trial { i + 1 } : e2e={ duration_box .value :.5f} ms, gpu={ gpu_time_ms :.5f} ms"
140+ f"Trial { i + 1 } : e2e={ duration_box .value :.5f} ms, gpu={ gpu_time_ms :.5f} ms" ,
141+ file = sys .stderr ,
142+ flush = True ,
121143 )
122144
123145 stats ["e2e" ] = test_compiler_util .get_timing_stats (e2e_times )
124146 stats ["gpu" ] = test_compiler_util .get_timing_stats (gpu_times )
125147 else : # CPU or other devices
126- hardware_name = platform .processor ()
127- print (
128- f"[Profiling] Using device: { args .device } { hardware_name } , warm up { args .warmup } , trials { args .trials } "
129- )
130-
131148 e2e_times = []
132149 for i in range (args .trials ):
133150 duration_box = test_compiler_util .DurationBox (- 1 )
@@ -140,50 +157,27 @@ def measure_performance(model_call, args, synchronizer_func):
140157 return outs , stats
141158
142159
143- def init_benchmark_result (args ):
144- if args .device == "cuda" :
145- hardware = paddle .device .cuda .get_device_name (0 )
146- elif args .device == "cpu" :
147- hardware = platform .processor ()
148- else :
149- hardware = "unknown"
150-
151- if args .compiler == "cinn" :
152- compile_framework_version = paddle .__version__
153- else :
154- compile_framework_version = "unknown"
155-
156- result_data = BenchmarkResult (
157- args = args ,
158- framework = "PaddlePaddle" ,
159- hardware = hardware ,
160- compile_framework_version = compile_framework_version ,
161- )
162- return result_data
163-
164-
165160def check_outputs (args , expected_out , compiled_out ):
166161 if isinstance (expected_out , paddle .Tensor ):
167162 expected_out = [expected_out ]
168163 if isinstance (compiled_out , paddle .Tensor ):
169164 compiled_out = [compiled_out ]
170165
171- eager_output_dtypes = [None ] * len (expected_out )
166+ eager_dtypes = [None ] * len (expected_out )
172167 for i , tensor in enumerate (expected_out ):
173- if tensor is not None :
174- eager_output_dtypes [i ] = str (tensor .dtype )
168+ eager_dtypes [i ] = (
169+ str (tensor .dtype ).replace ("paddle." , "" ) if tensor is not None else "none"
170+ )
175171
176- compiled_output_dtypes = [None ] * len (compiled_out )
172+ compiled_dtypes = [None ] * len (compiled_out )
177173 for i , tensor in enumerate (compiled_out ):
178- if tensor is not None :
179- compiled_output_dtypes [i ] = str (tensor .dtype )
174+ compiled_dtypes [i ] = (
175+ str (tensor .dtype ).replace ("paddle." , "" ) if tensor is not None else "none"
176+ )
180177
181- is_output_consistent = len (expected_out ) == len (compiled_out )
182- for a , b in zip (expected_out , compiled_out ):
183- if (a is None and b is not None ) or (a is not None and b is None ):
184- is_output_consistent = False
185- if a is not None and b is not None and a .dtype != b .dtype :
186- is_output_consistent = False
178+ type_match = test_compiler_util .check_output_datatype (
179+ args , eager_dtypes , compiled_dtypes
180+ )
187181
188182 def regular_outputs (origin_outputs ):
189183 outputs = []
@@ -197,9 +191,6 @@ def regular_outputs(origin_outputs):
197191 outputs .append (item )
198192 return outputs
199193
200- expected_out = regular_outputs (expected_out )
201- compiled_out = regular_outputs (compiled_out )
202-
203194 def print_cmp (key , func , ** kwargs ):
204195 try :
205196 cmp_ret = func (expected_out , compiled_out , ** kwargs )
@@ -210,23 +201,33 @@ def print_cmp(key, func, **kwargs):
210201 file = sys .stderr ,
211202 )
212203
213- print (
214- f"{ args .log_prompt } output_dtypes model_path:{ args .model_path } eager:{ eager_output_dtypes } compiled:{ compiled_output_dtypes } " ,
215- file = sys .stderr ,
216- )
217- print_cmp ("cmp.equal" , get_cmp_equal )
218- print_cmp ("cmp.all_close_atol8_rtol8" , get_cmp_all_close , atol = 1e-8 , rtol = 1e-8 )
219- print_cmp ("cmp.all_close_atol8_rtol5" , get_cmp_all_close , atol = 1e-8 , rtol = 1e-5 )
220- print_cmp ("cmp.all_close_atol5_rtol5" , get_cmp_all_close , atol = 1e-5 , rtol = 1e-5 )
221- print_cmp ("cmp.all_close_atol3_rtol2" , get_cmp_all_close , atol = 1e-3 , rtol = 1e-2 )
222- print_cmp ("cmp.all_close_atol2_rtol1" , get_cmp_all_close , atol = 1e-2 , rtol = 1e-1 )
223- print_cmp ("cmp.max_diff" , get_cmp_max_diff )
224- print_cmp ("cmp.mean_diff" , get_cmp_mean_diff )
225- print_cmp ("cmp.diff_count_atol8_rtol8" , get_cmp_diff_count , atol = 1e-8 , rtol = 1e-8 )
226- print_cmp ("cmp.diff_count_atol8_rtol5" , get_cmp_diff_count , atol = 1e-8 , rtol = 1e-5 )
227- print_cmp ("cmp.diff_count_atol5_rtol5" , get_cmp_diff_count , atol = 1e-5 , rtol = 1e-5 )
228- print_cmp ("cmp.diff_count_atol3_rtol2" , get_cmp_diff_count , atol = 1e-3 , rtol = 1e-2 )
229- print_cmp ("cmp.diff_count_atol2_rtol1" , get_cmp_diff_count , atol = 1e-2 , rtol = 1e-1 )
204+ if type_match :
205+ expected_out = regular_outputs (expected_out )
206+ compiled_out = regular_outputs (compiled_out )
207+
208+ print_cmp ("cmp.equal" , get_cmp_equal )
209+ print_cmp ("cmp.all_close_atol8_rtol8" , get_cmp_all_close , atol = 1e-8 , rtol = 1e-8 )
210+ print_cmp ("cmp.all_close_atol8_rtol5" , get_cmp_all_close , atol = 1e-8 , rtol = 1e-5 )
211+ print_cmp ("cmp.all_close_atol5_rtol5" , get_cmp_all_close , atol = 1e-5 , rtol = 1e-5 )
212+ print_cmp ("cmp.all_close_atol3_rtol2" , get_cmp_all_close , atol = 1e-3 , rtol = 1e-2 )
213+ print_cmp ("cmp.all_close_atol2_rtol1" , get_cmp_all_close , atol = 1e-2 , rtol = 1e-1 )
214+ print_cmp ("cmp.max_diff" , get_cmp_max_diff )
215+ print_cmp ("cmp.mean_diff" , get_cmp_mean_diff )
216+ print_cmp (
217+ "cmp.diff_count_atol8_rtol8" , get_cmp_diff_count , atol = 1e-8 , rtol = 1e-8
218+ )
219+ print_cmp (
220+ "cmp.diff_count_atol8_rtol5" , get_cmp_diff_count , atol = 1e-8 , rtol = 1e-5
221+ )
222+ print_cmp (
223+ "cmp.diff_count_atol5_rtol5" , get_cmp_diff_count , atol = 1e-5 , rtol = 1e-5
224+ )
225+ print_cmp (
226+ "cmp.diff_count_atol3_rtol2" , get_cmp_diff_count , atol = 1e-3 , rtol = 1e-2
227+ )
228+ print_cmp (
229+ "cmp.diff_count_atol2_rtol1" , get_cmp_diff_count , atol = 1e-2 , rtol = 1e-1
230+ )
230231
231232
232233def test_single_model (args ):
@@ -235,6 +236,10 @@ def test_single_model(args):
235236 model = get_model (args )
236237 model .eval ()
237238
239+ test_compiler_util .print_basic_config (
240+ args , get_hardward_name (args ), get_compile_framework_version (args )
241+ )
242+
238243 # Run on eager mode
239244 running_eager_success = False
240245 try :
0 commit comments