Skip to content

Commit 4e81944

Browse files
[GEMM] Sort benchmark shapes (#4516)
It is easier to plot the chart with performance shapes when they are sorted. --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent f223b65 commit 4e81944

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -233,23 +233,36 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
233233

234234

235235
X_VALS = [ #
236-
[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]
237-
] + [ #
238-
[1, m, n, 4096] for m in [1, 8] for n in [1024, 4096, 6144, 14336, 28672, 128256]
239-
] + [ #
240-
[1, m, 4096, 14336] for m in [1, 8]
241-
] + [ #
236+
[1, 1, 1024, 4096],
237+
[1, 1, 4096, 4096],
238+
[1, 1, 4096, 14336],
239+
[1, 1, 6144, 4096],
242240
[1, 1, 13824, 5120],
241+
[1, 1, 14336, 4096],
242+
[1, 1, 28672, 4096],
243+
[1, 1, 128256, 4096],
243244
[1, 4, 12288, 4096],
245+
[1, 8, 1024, 4096],
246+
[1, 8, 4096, 4096],
247+
[1, 8, 4096, 14336],
248+
[1, 8, 6144, 4096],
249+
[1, 8, 14336, 4096],
250+
[1, 8, 28672, 4096],
251+
[1, 8, 128256, 4096],
244252
[1, 512, 8192, 8192],
245253
[1, 512, 8192, 32768],
246254
[1, 512, 32768, 8192],
255+
[1, 1024, 1024, 1024],
247256
[1, 1024, 8192, 16384],
248257
[1, 1024, 8192, 28672],
258+
[1, 2048, 2048, 2048],
249259
[1, 3072, 3072, 4096], # FIXME: Remove this case when gemm_streamk_benchmark can get better performance
260+
[1, 4096, 4096, 4096],
250261
[1, 4096, 8192, 16384],
251262
[1, 8192, 1024, 16384],
263+
[1, 8192, 4096, 4096],
252264
[1, 8192, 4096, 16384],
265+
[1, 8192, 8192, 8192],
253266
[1, 16384, 1024, 8192],
254267
[1, 16384, 4096, 8192],
255268
[1, 16384, 8192, 1024],
@@ -259,7 +272,6 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
259272
[32, 4096, 128, 4096],
260273
[4096, 8, 128, 16384],
261274
[4096, 8, 16384, 128],
262-
[1, 8192, 4096, 4096],
263275
]
264276

265277
DEVICE_NAME = torch.xpu.get_device_name()

0 commit comments

Comments
 (0)