Skip to content

Commit 818e108

Browse files
[AMD] Update configs and add assume in matmul tutorial (#7443)
- Updated the AMD HIP autotune configs.
- Added `tl.assume` for the strides, `pid_m`, and `pid_n` to help guide integer analysis.
1 parent 0bd996b commit 818e108

File tree

1 file changed

+23
-16
lines changed

1 file changed

+23
-16
lines changed

python/tutorials/03-matrix-multiplication.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -205,23 +205,17 @@ def get_cuda_autotune_config():
205205

206206

207207
def get_hip_autotune_config():
208-
return [
209-
triton.Config(
210-
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
211-
num_warps=4, num_stages=2),
212-
triton.Config(
213-
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
214-
num_warps=8, num_stages=2),
215-
triton.Config(
216-
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
217-
num_warps=8, num_stages=2),
218-
triton.Config(
219-
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3},
220-
num_warps=4, num_stages=2),
221-
triton.Config(
222-
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8},
223-
num_warps=4, num_stages=2),
208+
sizes = [
209+
{'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
210+
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
211+
{'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
212+
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
213+
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
214+
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
215+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
216+
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
224217
]
218+
return [triton.Config(s | {'matrix_instr_nonkdim': 16}, num_warps=8, num_stages=2) for s in sizes]
225219

226220

227221
def get_autotune_config():
@@ -274,6 +268,19 @@ def matmul_kernel(
274268
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
275269
pid_n = (pid % num_pid_in_group) // group_size_m
276270

271+
# -----------------------------------------------------------
272+
# Add some integer bound assumptions.
273+
# This helps to guide integer analysis in the backend to optimize
274+
# load/store offset address calculation
275+
tl.assume(pid_m >= 0)
276+
tl.assume(pid_n >= 0)
277+
tl.assume(stride_am > 0)
278+
tl.assume(stride_ak > 0)
279+
tl.assume(stride_bn > 0)
280+
tl.assume(stride_bk > 0)
281+
tl.assume(stride_cm > 0)
282+
tl.assume(stride_cn > 0)
283+
277284
# ----------------------------------------------------------
278285
# Create pointers for the first blocks of A and B.
279286
# We will advance this pointer as we move in the K direction

0 commit comments

Comments
 (0)