Skip to content

Commit 97cf333

Browse files
[TUTORIAL][GEMM] Update autotune configs (#3479)
The autotune configs are synced from [gemm_benchmark.py](https://github.com/intel/intel-xpu-backend-for-triton/blob/main/benchmarks/triton_kernels_benchmark/gemm_benchmark.py#L25). Signed-off-by: Whitney Tsang <[email protected]>
1 parent e6263f2 commit 97cf333

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

python/tutorials/10-experimental-block-pointer.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,22 +90,35 @@
9090
# Final Result
9191
# ------------
9292

93+
import os
94+
9395
import torch
9496

9597
import triton
9698
import triton.language as tl
9799

100+
# True when TRITON_INTEL_ADVANCED_PATH is unset or '0'; only then are the
# ('small', 64) grf_mode/num_warps variants added to the autotune configs.
SMALL_GRF = os.environ.get('TRITON_INTEL_ADVANCED_PATH', '0') == '0'
101+
98102

99103
@triton.autotune(
100104
configs=[
101-
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2,
102-
num_warps=32),
103-
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=3,
104-
num_warps=32),
105-
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2,
106-
num_warps=32),
107-
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2,
108-
num_warps=32),
105+
triton.Config(
106+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
107+
num_stages=s, num_warps=32) for s in [1, 2, 3]
108+
] + [
109+
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
110+
num_stages=s, num_warps=w)
111+
for s in [2, 3, 4]
112+
for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
113+
] + [
114+
triton.Config(
115+
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
116+
num_stages=s, num_warps=32) for s in [2]
117+
] + [
118+
triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m},
119+
num_stages=s, num_warps=w)
120+
for s in [2, 3]
121+
for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
109122
],
110123
key=['M', 'N', 'K'],
111124
)

0 commit comments

Comments
 (0)