Commit f7c43d7

[XeTLA] Use stream-k implementation by default for one shape
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 6018c7b commit f7c43d7


1 file changed (+4, -0)


benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 4 additions & 0 deletions
@@ -309,6 +309,10 @@ def benchmark(B, M, N, K, provider):
 acc = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
 cnt = torch.empty((B, M, N), device='xpu', dtype=torch.int32)
 name = f'gemm_shape_{B}_{M}_{K}_{N}'
+# FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
+# better performance.
+if (B, M, N, K) == (1, 3072, 4096, 3072):
+    name = 'gemm_streamk_shape_3072_4096_3072'
 func = getattr(xetla_kernel, name)
 xetla_fn = lambda: func(a, b, c, acc, cnt)
 torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
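
The name selection in this change can be read in isolation as a small lookup. The sketch below is illustrative only: `select_xetla_kernel_name` is a hypothetical helper that does not exist in the benchmark, and it assumes nothing beyond the two name formats visible in the diff. Note that the default name orders the dimensions as B, M, K, N, while the shape check compares (B, M, N, K).

```python
def select_xetla_kernel_name(B, M, N, K):
    """Hypothetical helper mirroring the kernel-name selection in this diff."""
    # Default naming scheme used by the benchmark: K comes before N.
    name = f'gemm_shape_{B}_{M}_{K}_{N}'
    # Shape-specific override added by this commit: this one shape is
    # routed to the stream-k kernel instead of the default one.
    if (B, M, N, K) == (1, 3072, 4096, 3072):
        name = 'gemm_streamk_shape_3072_4096_3072'
    return name


# The overridden shape resolves to the stream-k kernel name.
print(select_xetla_kernel_name(1, 3072, 4096, 3072))
# Any other shape keeps the default name.
print(select_xetla_kernel_name(1, 1024, 1024, 1024))
```

In the benchmark itself the resulting string is passed to `getattr(xetla_kernel, name)`, so the override only works if `xetla_kernel` exposes an attribute with exactly that name.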
