@@ -309,6 +309,10 @@ def benchmark(B, M, N, K, provider):
309309 acc = torch .empty ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
310310 cnt = torch .empty ((B , M , N ), device = 'xpu' , dtype = torch .int32 )
311311 name = f'gemm_shape_{ B } _{ M } _{ K } _{ N } '
312+ # FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
313+ # better performance.
314+ if (B , M , N , K ) == (1 , 3072 , 4096 , 3072 ):
315+ name = 'gemm_streamk_shape_3072_4096_3072'
312316 func = getattr (xetla_kernel , name )
313317 xetla_fn = lambda : func (a , b , c , acc , cnt )
314318 torch_fn = lambda : torch .matmul (a , b ).to (torch .float32 )
@@ -338,6 +342,7 @@ def benchmark(B, M, N, K, provider):
338342 'gemm_shape_32_4096_4096_128' : 'Test_32x4096x4096x128_row_row' ,
339343 'gemm_shape_4096_8_128_16384' : 'Test_4096x8x128x16384_row_row' ,
340344 'gemm_shape_4096_8_16384_128' : 'Test_4096x8x16384x128_row_row' ,
345+ 'gemm_streamk_shape_3072_4096_3072' : 'stream_k_gemm_run' ,
341346 }
342347
343348 # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
0 commit comments