diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py index b70465ee71..7e0b339dc6 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py @@ -309,6 +309,10 @@ def benchmark(B, M, N, K, provider): acc = torch.empty((B, M, N), device='xpu', dtype=torch.float32) cnt = torch.empty((B, M, N), device='xpu', dtype=torch.int32) name = f'gemm_shape_{B}_{M}_{K}_{N}' + # FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get + # better performance. + if (B, M, N, K) == (1, 3072, 4096, 3072): + name = 'gemm_streamk_shape_3072_4096_3072' func = getattr(xetla_kernel, name) xetla_fn = lambda: func(a, b, c, acc, cnt) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) @@ -338,6 +342,7 @@ def benchmark(B, M, N, K, provider): 'gemm_shape_32_4096_4096_128': 'Test_32x4096x4096x128_row_row', 'gemm_shape_4096_8_128_16384': 'Test_4096x8x128x16384_row_row', 'gemm_shape_4096_8_16384_128': 'Test_4096x8x16384x128_row_row', + 'gemm_streamk_shape_3072_4096_3072': 'stream_k_gemm_run', } # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')