1010import triton .language as tl
1111
1212import triton_kernels_benchmark as benchmark_suit
13+ import xetla_kernel
1314
1415if benchmark_suit .USE_IPEX_OPTION :
1516 import intel_extension_for_pytorch # type: ignore # noqa: F401
@@ -253,9 +254,9 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
253254 line_arg = 'provider' ,
254255 # argument name whose value corresponds to a different line in the plot
255256 # possible values for `line_arg`
256- line_vals = ['triton' ],
257+ line_vals = ['triton' , 'xetla' ],
257258 # label name for the lines
258- line_names = ['Triton' ],
259+ line_names = ['Triton' , 'XeTLA' ],
259260 # line styles
260261 styles = [('green' , '-' ), ('green' , '--' ), ('blue' , '-' ), ('blue' , '--' )],
261262 ylabel = ['GB/s' , 'TFlops' ], # label name for the y-axis
@@ -280,6 +281,20 @@ def benchmark(M, N, K, provider):
280281 benchmark_suit .assert_close (triton_fn (), torch_fn (), atol = 1e-4 , rtol = 1e-2 , err_msg = 'triton to torch' )
281282 _ , min_ms , max_ms , mean_ms , cv = benchmark_suit .do_bench (triton_fn , warmup = 10 , rep = 10 , quantiles = quantiles ,
282283 kernel_name = ['first_wave' , 'full_tiles' ])
284+ elif provider == 'xetla' :
285+ c = torch .empty ((M , N ), device = 'xpu' , dtype = torch .float32 )
286+ acc = torch .empty ((M , N ), device = 'xpu' , dtype = torch .float32 )
287+ cnt = torch .empty ((M , N ), device = 'xpu' , dtype = torch .int32 )
288+
289+ name = f'gemm_streamk_shape_{ M } _{ K } _{ N } '
290+ func = getattr (xetla_kernel , name )
291+ xetla_fn = lambda : func (a , b , c , acc , cnt )
292+ torch_fn = lambda : torch .matmul (a , b ).to (torch .float32 )
293+
294+ # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
295+ _ , min_ms , max_ms , mean_ms , cv = benchmark_suit .do_bench (
296+ xetla_fn , warmup = 10 , rep = 10 , quantiles = quantiles ,
297+ kernel_name = 'gpu::xetla::kernel::gemm_universal_t<dispatch_stream_k' )
283298 else :
284299 raise NotImplementedError (f'Unsupported provider { provider } ' )
285300
0 commit comments