Skip to content

Commit 26baece

Browse files
[XeTLA] Use stream-k implementation by default for 3072x4096x3072 (#2496)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent b79ceaa commit 26baece

File tree

1 file changed

+5 -0 lines changed

1 file changed

+5 -0 lines changed

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -309,6 +309,10 @@ def benchmark(B, M, N, K, provider):
309309
acc = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
310310
cnt = torch.empty((B, M, N), device='xpu', dtype=torch.int32)
311311
name = f'gemm_shape_{B}_{M}_{K}_{N}'
312+
# FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
313+
# better performance.
314+
if (B, M, N, K) == (1, 3072, 4096, 3072):
315+
name = 'gemm_streamk_shape_3072_4096_3072'
312316
func = getattr(xetla_kernel, name)
313317
xetla_fn = lambda: func(a, b, c, acc, cnt)
314318
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
@@ -338,6 +342,7 @@ def benchmark(B, M, N, K, provider):
338342
'gemm_shape_32_4096_4096_128': 'Test_32x4096x4096x128_row_row',
339343
'gemm_shape_4096_8_128_16384': 'Test_4096x8x128x16384_row_row',
340344
'gemm_shape_4096_8_16384_128': 'Test_4096x8x16384x128_row_row',
345+
'gemm_streamk_shape_3072_4096_3072': 'stream_k_gemm_run',
341346
}
342347

343348
# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')

0 commit comments

Comments
 (0)