Skip to content

Commit 97ba3b3

Browse files
authored
[Bench][AMD] Add Assumptions to Enable Buffer Ops (#7742)
This PR adds `tl.assume` calls to the kernel so that the compiler can convert global loads into buffer loads. Note that this currently only takes effect for the weights and the weight scales.
1 parent bfc04bc commit 97ba3b3

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def swizzle2d(pid, grid_m, grid_n, GROUP_M: tl.constexpr):
4848
width = GROUP_M * grid_n
4949
group_id = pid // width
5050
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
51+
tl.assume(group_size >= 0)
5152
pid_m = group_id * GROUP_M + (pid % group_size)
5253
pid_n = (pid % width) // (group_size)
5354
return pid_m, pid_n

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,27 @@ def _matmul_ogs(
8080
SWAP_XW: tl.constexpr = False,
8181
IS_EPILOGUE_DEQUANT_MXFP8: tl.constexpr = False):
8282

83+
tl.assume(stride_y_k >= 0)
84+
tl.assume(stride_y_z >= 0)
85+
tl.assume(stride_y_m >= 0)
86+
tl.assume(stride_y_n >= 0)
87+
tl.assume(stride_x_z >= 0)
88+
tl.assume(stride_x_m >= 0)
89+
tl.assume(stride_x_k >= 0)
90+
tl.assume(stride_w_e >= 0)
91+
tl.assume(stride_w_k >= 0)
92+
tl.assume(stride_w_n >= 0)
93+
if stride_w_mx_e is not None:
94+
tl.assume(stride_w_mx_e >= 0)
95+
if stride_w_mx_k is not None:
96+
tl.assume(stride_w_mx_k >= 0)
97+
if stride_w_mx_n is not None:
98+
tl.assume(stride_w_mx_n >= 0)
99+
tl.assume(stride_b_e >= 0)
100+
tl.assume(batch_size >= 0)
101+
tl.assume(grid_m >= 0)
102+
tl.assume(grid_n >= 0)
103+
83104
is_w_microscaled: tl.constexpr = WMxScale is not None
84105
MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
85106
if is_w_microscaled:
@@ -116,7 +137,9 @@ def _matmul_ogs(
116137
HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None
117138
index_type: tl.constexpr = tl.int64 if UPCAST_INDICES else tl.int32
118139

119-
total_actual_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K
140+
unpadded_m = grid_m - padding_m
141+
tl.assume(unpadded_m >= 0)
142+
total_actual_tiles = batch_size * unpadded_m * grid_n * SPLIT_K
120143
if padding_m > 0 and pid >= total_actual_tiles:
121144
tl.device_assert(batch_size == 0)
122145
pid_mn = pid - total_actual_tiles
@@ -132,11 +155,11 @@ def _matmul_ogs(
132155
pid_emnk = pid
133156
if XCD_SWIZZLE != 1:
134157
pid_emnk = xcd_swizzle(pid_emnk, total_actual_tiles, XCD_SWIZZLE)
135-
pid_e = pid_emnk // ((grid_m - padding_m) * grid_n * SPLIT_K)
136-
pid_mnk = pid_emnk % ((grid_m - padding_m) * grid_n * SPLIT_K)
158+
pid_e = pid_emnk // (unpadded_m * grid_n * SPLIT_K)
159+
pid_mnk = pid_emnk % (unpadded_m * grid_n * SPLIT_K)
137160
pid_k = pid_mnk % SPLIT_K
138161
pid_mn = pid_mnk // SPLIT_K
139-
pid_m, pid_n = swizzle2d(pid_mn, (grid_m - padding_m), grid_n, GROUP_M)
162+
pid_m, pid_n = swizzle2d(pid_mn, unpadded_m, grid_n, GROUP_M)
140163
# For split-k, advance to the output k slice
141164
if SPLIT_K > 1:
142165
Y += pid_k.to( index_type) * stride_y_k

0 commit comments

Comments
 (0)