Skip to content

Commit 7f56c6d

Browse files
initialize results tensor
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 3f5c597 commit 7f56c6d

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
268268
def benchmark(B, M, N, K, provider):
269269
a_shape, b_shape = get_shapes(B, M, N, K, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
270270

271+
torch.manual_seed(0)
271272
a = torch.rand(a_shape, device='xpu', dtype=torch.bfloat16)
272273
b = torch.rand(b_shape, device='xpu', dtype=torch.bfloat16)
273274

@@ -291,10 +292,10 @@ def benchmark(B, M, N, K, provider):
291292
elif provider == 'triton':
292293
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
293294
if len(a.shape) == 3:
294-
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
295+
c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
295296
else:
296297
assert len(a.shape) == 2, 'Expecting shape of length 2'
297-
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
298+
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
298299
triton_fn = lambda: matmul(a, b, c, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
299300
torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
300301
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
@@ -304,13 +305,13 @@ def benchmark(B, M, N, K, provider):
304305
kernel_name='matmul_kernel_with_block_pointers')
305306
elif provider == 'xetla':
306307
if B == 1:
307-
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
308-
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
309-
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
308+
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
309+
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
310+
cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)
310311
else:
311-
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
312-
acc = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
313-
cnt = torch.empty((B, M, N), device='xpu', dtype=torch.int32)
312+
c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
313+
acc = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
314+
cnt = torch.zeros((B, M, N), device='xpu', dtype=torch.int32)
314315
name = f'gemm_shape_{B}_{M}_{K}_{N}'
315316
# FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
316317
# better performance.

0 commit comments

Comments
 (0)