Commit 7ad67a0
[GEMM] Add autotune configs of num_warps=64 (#3297)
The CI run at https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/13004109008 shows that the added configs are selected as the best autotune configs for some input shapes. The geomean performance gain is under 2%; the biggest per-shape improvement is 12%. Breakdown of the improvement per input shape:

B | M | N | K | ratio
-- | -- | -- | -- | --
1 | 1 | 13824 | 5120 | 1.07916
1 | 4 | 12288 | 4096 | 1.124677
1 | 512 | 8192 | 8192 | 1.081075
1 | 512 | 8192 | 32768 | 1.02523
1 | 8192 | 1024 | 16384 | 1.028267
1 | 16384 | 1024 | 8192 | 1.094302
1 | 16384 | 8192 | 1024 | 1.051715
4 | 32768 | 128 | 4096 | 1.019839
4 | 32768 | 4096 | 128 | 1.023059
32 | 4096 | 128 | 4096 | 1.019211

Signed-off-by: Whitney Tsang <[email protected]>
1 parent 393f700 commit 7ad67a0
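For illustration, the sketch below (not part of the commit) expands the first rewritten comprehension from the diff that follows. On the default path, i.e. when TRITON_INTEL_ADVANCED_PATH is unset or '0', each num_stages value now yields two configs: the pre-existing ('large', num_warps=32) variant and the new ('small', num_warps=64) one; the advanced path keeps only the former. The sketch assumes triton is importable.

    import os

    import triton

    # (grf_mode, num_warps) variants, mirroring the conditional in the diff:
    # the new ('small', 64) variant is only generated on the default path.
    variants = ([('large', 32), ('small', 64)]
                if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0'
                else [('large', 32)])

    configs = [
        triton.Config(
            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32,
             'GROUP_SIZE_M': 4, 'grf_mode': m},
            num_stages=s, num_warps=w)
        for s in [2, 3, 4] for (m, w) in variants
    ]

    # Default path: 3 stage counts x 2 variants = 6 configs; advanced path: 3.
    print(len(configs))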

File tree: 1 file changed, +9 −9 lines

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 9 additions & 9 deletions
@@ -26,17 +26,17 @@
         {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
         num_stages=s, num_warps=32) for s in [1, 2, 3]
     ] + [
-        triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
-            num_stages=s, num_warps=32) for s in [2, 3, 4]
+        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
+                      num_stages=s, num_warps=w) for s in [2, 3, 4] for (m, w) in
+        ([('large', 32), ('small', 64)] if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' else [('large', 32)])
     ] + [
         triton.Config(
             {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
             num_stages=s, num_warps=32) for s in [2]
     ] + [
-        triton.Config(
-            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
-            num_stages=s, num_warps=32) for s in [2, 3]
+        triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m},
+                      num_stages=s, num_warps=w) for s in [2, 3] for (m, w) in
+        ([('large', 32), ('small', 64)] if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' else [('large', 32)])
     ],
     key=['M', 'N', 'K'],
 )
@@ -91,9 +91,9 @@ def matmul_kernel_with_block_pointers(
         {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
         num_stages=s, num_warps=32) for s in [2, 3]
     ] + [
-        triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
-            num_stages=s, num_warps=32) for s in [2]
+        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
+                      num_stages=s, num_warps=w) for s in [2] for (m, w) in
+        ([('large', 32), ('small', 64)] if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' else [('large', 32)])
     ] + [
         triton.Config(
             {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 1024, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
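Pairing the new num_warps=64 configs with grf_mode 'small' rather than 'large' is consistent with how Intel GPUs allocate the register file: large GRF mode halves the number of hardware threads available per Xe core, so the higher warp count is presumably only schedulable in small GRF mode. On the advanced path (TRITON_INTEL_ADVANCED_PATH set to a non-'0' value) the config space is unchanged.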
