Skip to content

Commit faff365

Browse files
[GEMM] Undo small GRF autotune config for transpose A (#3323)
After #3297, we start to observe different kind of failures for `AtxB`, it is unclear if it is due to that change, but undoing for transpose A just in case.
1 parent 395bfb0 commit faff365

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
TRANSPOSE_A = os.getenv('TRANSPOSE_A', '0') == '1'
1919
TRANSPOSE_B = os.getenv('TRANSPOSE_B', '0') == '1'
2020
use_xetla = not (TRANSPOSE_A or TRANSPOSE_B)
21+
SMALL_GRF = os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' and not TRANSPOSE_A
2122

2223

2324
@triton.autotune(
@@ -27,16 +28,18 @@
2728
num_stages=s, num_warps=32) for s in [1, 2, 3]
2829
] + [
2930
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
30-
num_stages=s, num_warps=w) for s in [2, 3, 4] for (m, w) in
31-
([('large', 32), ('small', 64)] if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' else [('large', 32)])
31+
num_stages=s, num_warps=w)
32+
for s in [2, 3, 4]
33+
for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
3234
] + [
3335
triton.Config(
3436
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
3537
num_stages=s, num_warps=32) for s in [2]
3638
] + [
3739
triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m},
38-
num_stages=s, num_warps=w) for s in [2, 3] for (m, w) in
39-
([('large', 32), ('small', 64)] if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' else [('large', 32)])
40+
num_stages=s, num_warps=w)
41+
for s in [2, 3]
42+
for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
4043
],
4144
key=['M', 'N', 'K'],
4245
)
@@ -92,8 +95,9 @@ def matmul_kernel_with_block_pointers(
9295
num_stages=s, num_warps=32) for s in [2, 3]
9396
] + [
9497
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
95-
num_stages=s, num_warps=w) for s in [2] for (m, w) in
96-
([('large', 32), ('small', 64)] if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' else [('large', 32)])
98+
num_stages=s, num_warps=w)
99+
for s in [2]
100+
for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
97101
] + [
98102
triton.Config(
99103
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 1024, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},

0 commit comments

Comments
 (0)