 import triton.language as tl
 
 import triton_kernels_benchmark as benchmark_suit
+from triton_kernels_benchmark.benchmark_testing import do_bench_elapsed_time, BENCHMARKING_METHOD
+
 import xetla_kernel
 
 if benchmark_suit.USE_IPEX_OPTION:
@@ -250,9 +252,9 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
-        line_vals=['triton'] + (['xetla'] if use_xetla else []),
+        line_vals=['triton'] + (['xetla'] if use_xetla else ['onednn']),
         # label name for the lines
-        line_names=['Triton'] + (['XeTLA'] if use_xetla else []),
+        line_names=['Triton'] + (['XeTLA'] if use_xetla else ['onednn']),
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
@@ -277,8 +279,12 @@ def benchmark(B, M, N, K, provider):
         torch_b = torch.transpose(torch_b, -2, -1)
 
     if provider == 'onednn':
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(torch_a, torch_b), warmup=10,
-                                                                 rep=10, quantiles=quantiles)
+        do_bench = benchmark_suit.do_bench
+        if BENCHMARKING_METHOD == 'PYTORCH_LEGACY_PROFILER_USING_IPEX':
+            # The legacy profiler shows ~6000 TFLOPS GeoMean for onednn measurements, so use a more reliable method
+            do_bench = do_bench_elapsed_time
+        _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), warmup=10, rep=10,
+                                                  quantiles=quantiles, kernel_name='gemm_kernel')
     elif provider == 'triton':
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
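For context, the added branch picks a timing function based on the configured benchmarking method and falls back to elapsed-time measurement when the legacy IPEX profiler is in use. Below is a minimal sketch of that selection, mirroring the diff; `pick_do_bench` and `bench_onednn_matmul` are illustrative names, not helpers from the benchmark suite.

```python
# Illustrative sketch of the fallback introduced in this change; not a drop-in
# replacement for the benchmark function in the diff.
import torch

import triton_kernels_benchmark as benchmark_suit
from triton_kernels_benchmark.benchmark_testing import do_bench_elapsed_time, BENCHMARKING_METHOD


def pick_do_bench():
    # The legacy IPEX profiler over-reports onednn throughput (~6000 TFLOPS GeoMean),
    # so fall back to elapsed-time (wall-clock) measurement for that configuration.
    if BENCHMARKING_METHOD == 'PYTORCH_LEGACY_PROFILER_USING_IPEX':
        return do_bench_elapsed_time
    return benchmark_suit.do_bench


def bench_onednn_matmul(torch_a, torch_b, quantiles):
    # Same call shape as the onednn branch in the diff: time torch.matmul and
    # report the requested quantiles for the 'gemm_kernel' measurement.
    do_bench = pick_do_bench()
    return do_bench(lambda: torch.matmul(torch_a, torch_b), warmup=10, rep=10,
                    quantiles=quantiles, kernel_name='gemm_kernel')
```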