Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ def benchmark(B, M, N, K, provider):
acc = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
cnt = torch.empty((B, M, N), device='xpu', dtype=torch.int32)
name = f'gemm_shape_{B}_{M}_{K}_{N}'
# FIXME: Use gemm_streamk_benchmark.py instead once Triton stream-K achieves
# better performance.
if (B, M, N, K) == (1, 3072, 4096, 3072):
name = 'gemm_streamk_shape_3072_4096_3072'
func = getattr(xetla_kernel, name)
xetla_fn = lambda: func(a, b, c, acc, cnt)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
Expand Down