Skip to content

Commit d82e3ea

Browse files
Cherry-pick from #3026
1 parent 352de9d commit d82e3ea

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,19 +305,28 @@ def benchmark(B, M, N, K, provider):
305305
elif provider == 'xetla':
306306
if B == 1:
307307
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
308-
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
309308
cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)
310309
else:
311310
c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
312-
acc = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
313311
cnt = torch.zeros((B, M, N), device='xpu', dtype=torch.int32)
314312
name = f'gemm_shape_{B}_{M}_{K}_{N}'
315313
# FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
316314
# better performance.
317315
if (B, M, N, K) == (1, 3072, 3072, 4096):
318316
name = 'gemm_streamk_shape_3072_4096_3072'
319317
func = getattr(xetla_kernel, name)
320-
xetla_fn = lambda: func(a, b, c, acc, cnt)
318+
319+
320+
def xetla_func_with_acc_allocation():
321+
# allocating `acc` matrix on every function call, to be as similar as
322+
# possible to the triton kernel, which also does this on every call.
323+
if B == 1:
324+
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
325+
else:
326+
acc = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
327+
return func(a, b, c, acc, cnt)
328+
329+
xetla_fn = xetla_func_with_acc_allocation
321330
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
322331

323332
# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')

0 commit comments

Comments
 (0)