 import triton.language as tl

 import triton_kernels_benchmark as benchmark_suit
+import xetla_kernel

 if benchmark_suit.USE_IPEX_OPTION:
     import intel_extension_for_pytorch  # type: ignore  # noqa: F401
@@ -253,9 +254,9 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg`
-        line_vals=['triton'],
+        line_vals=['triton', 'xetla'],
         # label name for the lines
-        line_names=['Triton'],
+        line_names=['Triton', 'XeTLA'],
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
@@ -281,6 +282,20 @@ def benchmark(M, N, K, provider):
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles,
                                                                  kernel_name=['first_wave', 'full_tiles'])
+    elif provider == 'xetla':
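+        # Output buffer plus stream-K workspace: judging by their names and types,
+        # acc holds cross-workgroup partial sums and cnt holds per-tile counters.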
+        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
+
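+        # The xetla_kernel module exposes a shape-specialized stream-K GEMM per
+        # (M, K, N); resolve the matching entry point by its encoded name.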
+        name = f'gemm_streamk_shape_{M}_{K}_{N}'
+        func = getattr(xetla_kernel, name)
+        xetla_fn = lambda: func(a, b, c, acc, cnt)
+        torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
+
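+        # torch_fn is the correctness reference; re-enable the assert_close
+        # check below to validate the XeTLA output against it.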
+        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
+            xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
+            kernel_name='gpu::xetla::kernel::gemm_universal_t<dispatch_stream_k')
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
