Skip to content

Commit e89244d

Browse files
Rename GRF mode values
Signed-off-by: Whitney Tsang <[email protected]>
1 parent d478b30 commit e89244d

File tree

15 files changed

+73
-71
lines changed

15 files changed

+73
-71
lines changed

benchmarks/third_party/sglang/scaled_mm_benchmark.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,20 @@ def is_weak_contiguous(x: torch.Tensor):
2323

2424
def get_matmul_batched_autotune_configs() -> List[triton.Config]:
2525
configs = [
26-
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'grf_mode': 'large'}, num_stages=s, num_warps=32)
26+
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'grf_mode': '256'}, num_stages=s, num_warps=32)
2727
for s in [2, 3]
2828
] + [
2929
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': m}, num_stages=s, num_warps=w)
3030
for s in [2]
31-
for (m, w) in ([('large', 32), ('small', 64)])
31+
for (m, w) in ([('256', 32), ('128', 64)])
3232
] + [
33-
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': 'large'}, num_stages=s, num_warps=32)
33+
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': '256'}, num_stages=s, num_warps=32)
3434
for s in [2]
3535
] + [
36-
triton.Config({'BLOCK_M': 8, 'BLOCK_N': 512, 'BLOCK_K': 64, 'grf_mode': 'large'}, num_stages=s, num_warps=32)
36+
triton.Config({'BLOCK_M': 8, 'BLOCK_N': 512, 'BLOCK_K': 64, 'grf_mode': '256'}, num_stages=s, num_warps=32)
3737
for s in [2]
3838
] + [
39-
triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': 'large'}, num_stages=s, num_warps=4)
39+
triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': '256'}, num_stages=s, num_warps=4)
4040
for s in [2]
4141
]
4242
return configs

benchmarks/third_party/vllm/batched_moe_benchmark.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,20 +201,20 @@ def expert_triton_kernel(
201201

202202
def get_matmul_batched_autotune_configs() -> List[triton.Config]:
203203
configs = [
204-
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'grf_mode': 'large'}, num_stages=s, num_warps=32)
204+
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'grf_mode': '256'}, num_stages=s, num_warps=32)
205205
for s in [2, 3]
206206
] + [
207207
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': m}, num_stages=s, num_warps=w)
208208
for s in [2]
209-
for (m, w) in ([('large', 32), ('small', 64)])
209+
for (m, w) in ([('256', 32), ('128', 64)])
210210
] + [
211-
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': 'large'}, num_stages=s, num_warps=32)
211+
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': '256'}, num_stages=s, num_warps=32)
212212
for s in [2]
213213
] + [
214-
triton.Config({'BLOCK_M': 8, 'BLOCK_N': 512, 'BLOCK_K': 64, 'grf_mode': 'large'}, num_stages=s, num_warps=32)
214+
triton.Config({'BLOCK_M': 8, 'BLOCK_N': 512, 'BLOCK_K': 64, 'grf_mode': '256'}, num_stages=s, num_warps=32)
215215
for s in [2]
216216
] + [
217-
triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': 'large'}, num_stages=s, num_warps=4)
217+
triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': '256'}, num_stages=s, num_warps=4)
218218
for s in [2]
219219
]
220220
return configs

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def _attn_fwd_with_block_pointers(Q, K, V, sm_scale, M, Out, #
171171

172172

173173
configs = [
174-
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large'}, num_stages=s, num_warps=w) \
174+
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': '256'}, num_stages=s, num_warps=w) \
175175
for BM in [128, 256] \
176176
for BN in [32, 64] \
177177
for s in [2, 3, 4] \

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,18 @@
2121
def get_matmul_autotune_configs() -> List[triton.Config]:
2222
configs = [
2323
triton.Config(
24-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
24+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
2525
num_stages=s, num_warps=32) for s in [1, 2, 3]
2626
] + [
2727
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
28-
num_stages=s, num_warps=w) for s in [2, 3, 4] for (m, w) in ([('large', 32), ('small', 64)])
28+
num_stages=s, num_warps=w) for s in [2, 3, 4] for (m, w) in ([('256', 32), ('128', 64)])
2929
] + [
3030
triton.Config(
31-
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
31+
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
3232
num_stages=s, num_warps=32) for s in [2]
3333
] + [
3434
triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m},
35-
num_stages=s, num_warps=w) for s in [2, 3] for (m, w) in ([('large', 32), ('small', 64)])
35+
num_stages=s, num_warps=w) for s in [2, 3] for (m, w) in ([('256', 32), ('128', 64)])
3636
]
3737
return configs
3838

@@ -88,26 +88,26 @@ def matmul_kernel_with_block_pointers(
8888
def get_matmul_batched_autotune_configs() -> List[triton.Config]:
8989
configs = [
9090
triton.Config(
91-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
91+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
9292
num_stages=s, num_warps=32) for s in [2, 3]
9393
] + [
9494
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
95-
num_stages=s, num_warps=w) for s in [2] for (m, w) in ([('large', 32), ('small', 64)])
95+
num_stages=s, num_warps=w) for s in [2] for (m, w) in ([('256', 32), ('128', 64)])
9696
] + [
9797
triton.Config(
98-
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 1024, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
98+
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 1024, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
9999
num_stages=s, num_warps=32) for s in [2, 3]
100100
] + [
101101
triton.Config(
102-
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
102+
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
103103
num_stages=s, num_warps=32) for s in [2]
104104
] + [
105105
triton.Config(
106-
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
106+
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
107107
num_stages=s, num_warps=32) for s in [2]
108108
] + [
109109
triton.Config(
110-
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
110+
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
111111
num_stages=s, num_warps=4) for s in [2]
112112
]
113113
return configs

benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,19 @@ def suffix():
3737
@triton.autotune(
3838
configs=[
3939
triton.Config(
40-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
40+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
4141
num_stages=2, num_warps=32),
4242
triton.Config(
43-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
43+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
4444
num_stages=3, num_warps=32),
4545
triton.Config(
46-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
46+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
4747
num_stages=2, num_warps=32),
4848
triton.Config(
49-
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
49+
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
5050
num_stages=2, num_warps=32),
5151
triton.Config(
52-
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
52+
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
5353
num_stages=2, num_warps=32),
5454
],
5555
key=['M', 'N', 'K'],
@@ -105,22 +105,22 @@ def matmul_kernel_with_tensor_descriptors(
105105
@triton.autotune(
106106
configs=[
107107
triton.Config(
108-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
108+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
109109
num_stages=2, num_warps=32),
110110
triton.Config(
111-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
111+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
112112
num_stages=3, num_warps=32),
113113
triton.Config(
114-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
114+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
115115
num_stages=2, num_warps=32),
116116
triton.Config(
117-
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
117+
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
118118
num_stages=2, num_warps=32),
119119
triton.Config(
120-
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
120+
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
121121
num_stages=2, num_warps=32),
122122
triton.Config(
123-
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
123+
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
124124
num_stages=2, num_warps=4),
125125
],
126126
key=['M', 'N', 'K'],

benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,19 @@ def gelu(x):
3535
@triton.autotune(
3636
configs=[
3737
triton.Config(
38-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
38+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
3939
num_stages=2, num_warps=32),
4040
triton.Config(
41-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
41+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
4242
num_stages=3, num_warps=32),
4343
triton.Config(
44-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
44+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
4545
num_stages=2, num_warps=32),
4646
triton.Config(
47-
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
47+
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
4848
num_stages=2, num_warps=32),
4949
triton.Config(
50-
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
50+
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
5151
num_stages=2, num_warps=32),
5252
],
5353
key=['M', 'N', 'K'],
@@ -97,22 +97,22 @@ def matmul_kernel_with_tensor_descriptors(
9797
@triton.autotune(
9898
configs=[
9999
triton.Config(
100-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
100+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
101101
num_stages=2, num_warps=32),
102102
triton.Config(
103-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
103+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
104104
num_stages=3, num_warps=32),
105105
triton.Config(
106-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
106+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
107107
num_stages=2, num_warps=32),
108108
triton.Config(
109-
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
109+
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
110110
num_stages=2, num_warps=32),
111111
triton.Config(
112-
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
112+
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
113113
num_stages=2, num_warps=32),
114114
triton.Config(
115-
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
115+
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
116116
num_stages=2, num_warps=4),
117117
],
118118
key=['M', 'N', 'K'],

benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,19 @@
1717
@triton.autotune(
1818
configs=[
1919
triton.Config(
20-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
20+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
2121
num_stages=2, num_warps=32),
2222
triton.Config(
23-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
23+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
2424
num_stages=3, num_warps=32),
2525
triton.Config(
26-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
26+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
2727
num_stages=2, num_warps=32),
2828
triton.Config(
29-
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
29+
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
3030
num_stages=2, num_warps=32),
3131
triton.Config(
32-
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
32+
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
3333
num_stages=2, num_warps=32),
3434
],
3535
key=['M', 'N', 'K'],
@@ -82,22 +82,22 @@ def matmul_kernel_with_tensor_descriptors(
8282
@triton.autotune(
8383
configs=[
8484
triton.Config(
85-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
85+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
8686
num_stages=2, num_warps=32),
8787
triton.Config(
88-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
88+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
8989
num_stages=3, num_warps=32),
9090
triton.Config(
91-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
91+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
9292
num_stages=2, num_warps=32),
9393
triton.Config(
94-
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
94+
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
9595
num_stages=2, num_warps=32),
9696
triton.Config(
97-
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
97+
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
9898
num_stages=2, num_warps=32),
9999
triton.Config(
100-
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
100+
{'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
101101
num_stages=2, num_warps=4),
102102
],
103103
key=['M', 'N', 'K'],

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
@triton.autotune(
1616
configs=[
17-
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 4, 'SPLIT_K': 4, 'grf_mode': 'large'},
17+
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 4, 'SPLIT_K': 4, 'grf_mode': '256'},
1818
num_stages=4, num_warps=32),
1919
],
2020
key=['M', 'N', 'K'],

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def mac_loop(
9696
@triton.autotune(
9797
configs=[
9898
triton.Config(
99-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
99+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
100100
num_stages=2, num_warps=32),
101101
],
102102
key=['M', 'N', 'K'],
@@ -132,7 +132,7 @@ def first_wave(
132132
@triton.autotune(
133133
configs=[
134134
triton.Config(
135-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
135+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
136136
num_stages=2, num_warps=32),
137137
],
138138
key=['M', 'N', 'K'],

python/test/unit/language/test_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1299,7 +1299,7 @@ def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bo
12991299
if is_xpu() and (128, 256, 256) == (BLOCK_M, BLOCK_N, BLOCK_K) and not CONST_SCALE and not PACK_B_ALONG_K:
13001300
kernel_kwargs["num_warps"] = 8
13011301
if is_xpu():
1302-
kernel_kwargs["grf_mode"] = "large"
1302+
kernel_kwargs["grf_mode"] = "256"
13031303
out = mxfp8_mxfp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, stride_scale, a.stride(0), a.stride(1),
13041304
b.stride(0), b.stride(1), output.stride(0), output.stride(1), not CONST_SCALE,
13051305
dtype_converter[A_DATA_TYPE], dtype_converter[B_DATA_TYPE], BLOCK_M, BLOCK_N,

0 commit comments

Comments
 (0)