diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
index 3433655303..06d2d90e1d 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -128,6 +128,7 @@ def forward(ctx, a, b, c, acc_dtype=None):
             [512, 32768, 8192],
             [1024, 28672, 8192],
             [3072, 4096, 3072],
+            [4096, 4096, 4096],
         ],
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
@@ -152,7 +153,7 @@ def benchmark(M, N, K, provider):
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     elif provider == 'triton':
-        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
@@ -160,9 +161,9 @@ def benchmark(M, N, K, provider):
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles, kernel_name='_kernel')
     elif provider == 'xetla':
-        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
-        acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
-        cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
+        c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+        acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+        cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)
 
         name = f'gemm_splitk_shape_{M}_{K}_{N}'
         func = getattr(xetla_kernel, name)
diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
index 1389eb9eb1..12f37e9d31 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -275,7 +275,7 @@ def benchmark(M, N, K, provider):
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     elif provider == 'triton':
-        c = torch.empty((M, N), device=a.device, dtype=torch.float32)
+        c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
@@ -283,9 +283,9 @@ def benchmark(M, N, K, provider):
                                                                  quantiles=quantiles,
                                                                  kernel_name=['first_wave', 'full_tiles'])
     elif provider == 'xetla':
-        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
-        acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
-        cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
+        c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+        acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+        cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)
 
         name = f'gemm_streamk_shape_{M}_{K}_{N}'
         func = getattr(xetla_kernel, name)
diff --git a/benchmarks/xetla_kernel/python_main.cpp b/benchmarks/xetla_kernel/python_main.cpp
index 80dc03ef51..bad40faa78 100644
--- a/benchmarks/xetla_kernel/python_main.cpp
+++ b/benchmarks/xetla_kernel/python_main.cpp
@@ -317,6 +317,9 @@ PYBIND11_MODULE(xetla_kernel, m) {
   m.def("gemm_splitk_shape_3072_4096_3072",
         &bf16_split_k_gemm<3072, 4096, 3072, kslicing_impl_t::global>,
         "bf16_gemm_splitk (XeTLA)");
+  m.def("gemm_splitk_shape_4096_4096_4096",
+        &bf16_split_k_gemm<4096, 4096, 4096, kslicing_impl_t::global>,
+        "bf16_gemm_splitk (XeTLA)");
 
   // flash_attn
   m.def("flash_attn_causal_false", &flash_attn, "flash attn fwd (XeTLA)");