Skip to content

Commit 02346d9

Browse files
[GEMM] Fix typo for streamk 3072x4096x3072 case (#2898)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent e7292af commit 02346d9

File tree

1 file changed

+20
-19
lines changed

1 file changed

+20
-19
lines changed

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -227,28 +227,28 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
227227
@benchmark_suit.perf_report(
228228
benchmark_suit.Benchmark(
229229
# argument names to use as an x-axis for the plot
230-
x_names=['B', 'M', 'K', 'N'],
230+
x_names=['B', 'M', 'N', 'K'],
231231
# different possible values for `x_name`
232232
x_vals=[[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + #
233233
[ #
234-
[1, 1, 5120, 13824], #
235-
[1, 4, 4096, 12288], #
234+
[1, 1, 13824, 5120], #
235+
[1, 4, 12288, 4096], #
236236
[1, 512, 8192, 8192], #
237237
[1, 512, 8192, 32768], #
238238
[1, 512, 32768, 8192], #
239-
[1, 1024, 16384, 8192], #
240-
[1, 1024, 28672, 8192], #
241-
[1, 3072, 4096, 3072], # FIXME: Remove this case when gemm_streamk_benchmark can get better performance
242-
[1, 4096, 16384, 8192], #
243-
[1, 8192, 16384, 1024], #
244-
[1, 8192, 16384, 4096], #
239+
[1, 1024, 8192, 16384], #
240+
[1, 1024, 8192, 28672], #
241+
[1, 3072, 3072, 4096], # FIXME: Remove this case when gemm_streamk_benchmark can get better performance
242+
[1, 4096, 8192, 16384], #
243+
[1, 8192, 1024, 16384], #
244+
[1, 8192, 4096, 16384], #
245245
[1, 16384, 1024, 8192], #
246246
[1, 16384, 4096, 8192], #
247247
[1, 16384, 8192, 1024], #
248248
[1, 16384, 8192, 4096], #
249249
[4, 32768, 128, 4096], #
250250
[4, 32768, 4096, 128], #
251-
[32, 4096, 4096, 128], #
251+
[32, 4096, 128, 4096], #
252252
[4096, 8, 128, 16384], #
253253
[4096, 8, 16384, 128]
254254
],
@@ -268,6 +268,7 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
268268
def benchmark(B, M, N, K, provider):
269269
a_shape, b_shape = get_shapes(B, M, N, K, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
270270

271+
torch.manual_seed(0)
271272
a = torch.rand(a_shape, device='xpu', dtype=torch.bfloat16)
272273
b = torch.rand(b_shape, device='xpu', dtype=torch.bfloat16)
273274

@@ -291,10 +292,10 @@ def benchmark(B, M, N, K, provider):
291292
elif provider == 'triton':
292293
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
293294
if len(a.shape) == 3:
294-
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
295+
c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
295296
else:
296297
assert len(a.shape) == 2, 'Expecting shape of length 2'
297-
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
298+
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
298299
triton_fn = lambda: matmul(a, b, c, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
299300
torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
300301
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
@@ -304,17 +305,17 @@ def benchmark(B, M, N, K, provider):
304305
kernel_name='matmul_kernel_with_block_pointers')
305306
elif provider == 'xetla':
306307
if B == 1:
307-
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
308-
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
309-
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
308+
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
309+
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
310+
cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)
310311
else:
311-
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
312-
acc = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
313-
cnt = torch.empty((B, M, N), device='xpu', dtype=torch.int32)
312+
c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
313+
acc = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
314+
cnt = torch.zeros((B, M, N), device='xpu', dtype=torch.int32)
314315
name = f'gemm_shape_{B}_{M}_{K}_{N}'
315316
# FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
316317
# better performance.
317-
if (B, M, N, K) == (1, 3072, 4096, 3072):
318+
if (B, M, N, K) == (1, 3072, 3072, 4096):
318319
name = 'gemm_streamk_shape_3072_4096_3072'
319320
func = getattr(xetla_kernel, name)
320321
xetla_fn = lambda: func(a, b, c, acc, cnt)

0 commit comments

Comments
 (0)