
Commit 7eb41bf

Make acc matrix allocation on each call for XeTLA GEMM benchmarks (#3026)
Comparing https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/12382880184/job/34564504020 (main) with https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/12390456716/job/34585505155 (this PR), the degradation from this pull request for XeTLA is ~3%.

However, this is also a **potential opportunity** to improve the Triton kernel by allocating the accumulation matrix only once. If that is implemented for Triton, this pull request will need to be rolled back for XeTLA.

---------

Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 4e654b6 commit 7eb41bf
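
For context on why the placement of the allocation matters to the measurement, below is a minimal, self-contained sketch of the two timing styles being compared. The `bench` helper, the CPU tensors, and the shapes are hypothetical stand-ins for the repository's actual XPU benchmark harness; only the placement of the `acc` allocation is the point.

import time

import torch


def bench(fn, iters=100):
    # Hypothetical timing helper: average wall-clock milliseconds per call.
    fn()  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e3


M = N = K = 1024
a = torch.randn((M, K))
b = torch.randn((K, N))

# Style measured before this commit: `acc` is allocated once, outside the timed call.
acc_prealloc = torch.zeros((M, N), dtype=torch.float32)
fn_prealloc = lambda: torch.addmm(acc_prealloc, a, b)


# Style measured after this commit: `acc` is allocated inside the timed call,
# matching the Triton kernel benchmark, which also allocates per call.
def fn_alloc_each_call():
    acc = torch.zeros((M, N), dtype=torch.float32)
    return torch.addmm(acc, a, b)


print(f'preallocated acc:    {bench(fn_prealloc):.3f} ms/call')
print(f'per-call allocation: {bench(fn_alloc_each_call):.3f} ms/call')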

File tree: 1 file changed, +11 / -3 lines


benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 11 additions & 3 deletions
@@ -306,19 +306,27 @@ def benchmark(B, M, N, K, provider):
     elif provider == 'xetla':
         if B == 1:
             c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
-            acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
             cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)
         else:
             c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
-            acc = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
             cnt = torch.zeros((B, M, N), device='xpu', dtype=torch.int32)
         name = f'gemm_shape_{B}_{M}_{K}_{N}'
         # FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
         # better performance.
         if (B, M, N, K) == (1, 3072, 3072, 4096):
             name = 'gemm_streamk_shape_3072_4096_3072'
         func = getattr(xetla_kernel, name)
-        xetla_fn = lambda: func(a, b, c, acc, cnt)
+
+        def xetla_func_with_acc_allocation():
+            # allocating `acc` matrix on every function call, to be as similar as
+            # possible to the triton kernel, which also does this on every call.
+            if B == 1:
+                acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+            else:
+                acc = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
+            return func(a, b, c, acc, cnt)
+
+        xetla_fn = xetla_func_with_acc_allocation
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
 
         kernels_name = {
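
The "potential opportunity" from the commit message would go the other way on the Triton side: hoist the accumulator allocation out of the benchmarked closure and reuse it across calls. A rough sketch of that shape, where `triton_matmul` and the surrounding variable names are hypothetical placeholders rather than the file's actual API:

# Hypothetical sketch: allocate the accumulation matrix once and reuse it,
# resetting it in place instead of allocating on every benchmarked call.
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)

def triton_fn_with_reused_acc():
    acc.zero_()                      # reset in place; no per-call allocation
    return triton_matmul(a, b, acc)  # `triton_matmul` is a placeholder name

If the Triton benchmark adopted something like this, the XeTLA change in this commit would need to be rolled back to keep the two measurements comparable.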
