
Commit 932d0be

Fix assertion error on gemm_splitk_benchmark.py (#2717)
The cause of this error is that the results tensor must be zero-initialized when using `atomic_add`; otherwise the kernel reads dirty memory left over from previous benchmarking cases.
1 parent 8382e76 · commit 932d0be
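Why zeroing matters here, as a minimal sketch: in a split-K GEMM each program reduces its own slice of the K dimension and the SPLIT_K partial results for a tile are combined with `tl.atomic_add`, so whatever the output buffer already contains is folded into the final value. This is not the benchmark's actual `_kernel`; the shapes, block sizes, fp32-only dtypes, and grid below are assumptions chosen for illustration.

import torch
import triton
import triton.language as tl


@triton.jit
def splitk_gemm(a_ptr, b_ptr, c_ptr, M, N, K,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                BLOCK_K: tl.constexpr, SPLIT_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_k = tl.program_id(2)  # which slice of the K dimension this program owns
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Each program reduces only its own K-slice ...
    for k in range(pid_k * BLOCK_K, K, SPLIT_K * BLOCK_K):
        rk = k + tl.arange(0, BLOCK_K)
        a = tl.load(a_ptr + rm[:, None] * K + rk[None, :])
        b = tl.load(b_ptr + rk[:, None] * N + rn[None, :])
        acc += tl.dot(a, b)
    # ... and the SPLIT_K partial sums for the same (m, n) tile are added
    # atomically into c, which therefore has to start at zero.
    tl.atomic_add(c_ptr + rm[:, None] * N + rn[None, :], acc)


M = N = K = 256
a = torch.randn((M, K), device='xpu', dtype=torch.float32)
b = torch.randn((K, N), device='xpu', dtype=torch.float32)
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)  # zeros, not empty
grid = (M // 64, N // 64, 4)
splitk_gemm[grid](a, b, c, M, N, K, BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, SPLIT_K=4)
torch.testing.assert_close(c, a @ b, atol=1e-2, rtol=1e-2)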

3 files changed, +12 -8 lines changed


benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 5 additions & 4 deletions
@@ -128,6 +128,7 @@ def forward(ctx, a, b, c, acc_dtype=None):
             [512, 32768, 8192],
             [1024, 28672, 8192],
             [3072, 4096, 3072],
+            [4096, 4096, 4096],
         ],
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
@@ -152,17 +153,17 @@ def benchmark(M, N, K, provider):
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     elif provider == 'triton':
-        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles, kernel_name='_kernel')
     elif provider == 'xetla':
-        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
-        acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
-        cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
+        c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+        acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+        cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)

         name = f'gemm_splitk_shape_{M}_{K}_{N}'
         func = getattr(xetla_kernel, name)
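The "dirty memory from previous benchmarking cases" part of the commit message can be illustrated with PyTorch's caching allocator, as in this small, hedged sketch: `torch.empty` may hand back a cached block that still holds an earlier case's values, while `torch.zeros` guarantees a clean buffer to accumulate into. (Whether stale values actually reappear depends on the allocator, so this only demonstrates the hazard, not guaranteed behavior.)

import torch

old = torch.full((1024,), 123.0, device='xpu')  # pretend this is a previous case's output
del old                                         # its allocation returns to the caching allocator
dirty = torch.empty((1024,), device='xpu')      # may reuse that block; contents are undefined
print(dirty[:4])                                # can show 123.0 -- stale data, not zeros
clean = torch.zeros((1024,), device='xpu')      # explicitly zeroed, safe to atomic_add into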

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 4 additions & 4 deletions
@@ -275,17 +275,17 @@ def benchmark(M, N, K, provider):
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     elif provider == 'triton':
-        c = torch.empty((M, N), device=a.device, dtype=torch.float32)
+        c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles,
                                                                  kernel_name=['first_wave', 'full_tiles'])
     elif provider == 'xetla':
-        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
-        acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
-        cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
+        c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+        acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+        cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)

         name = f'gemm_streamk_shape_{M}_{K}_{N}'
         func = getattr(xetla_kernel, name)

benchmarks/xetla_kernel/python_main.cpp

Lines changed: 3 additions & 0 deletions
@@ -317,6 +317,9 @@ PYBIND11_MODULE(xetla_kernel, m) {
   m.def("gemm_splitk_shape_3072_4096_3072",
         &bf16_split_k_gemm<3072, 4096, 3072, kslicing_impl_t::global>,
         "bf16_gemm_splitk (XeTLA)");
+  m.def("gemm_splitk_shape_4096_4096_4096",
+        &bf16_split_k_gemm<4096, 4096, 4096, kslicing_impl_t::global>,
+        "bf16_gemm_splitk (XeTLA)");
   // flash_attn
   m.def("flash_attn_causal_false", &flash_attn<false, false, false>,
         "flash attn fwd (XeTLA)");
