Commit 4327b5b

[TRITON_KERNELS] pad tensors in HopperValue layout (#8677)
1 parent 240a5c8 commit 4327b5b

5 files changed, +18 -7 lines changed


python/triton_kernels/tests/test_matmul.py

Lines changed: 0 additions & 3 deletions
@@ -376,9 +376,6 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gamm
     if torch.cuda.get_device_capability()[0] < 10:
         if "mxfloat4" not in weight_dtype_str:
             pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
-        if k % 64 != 0 or n % 64 != 0:
-            # Automatic padding not implemented for Hopper swizzle
-            pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
 
     expt_is_inner = (inner_expt_opt is not None)
     if expt_is_inner:

python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py

Lines changed: 3 additions & 0 deletions
@@ -73,6 +73,7 @@ def _load_tile_attrs(
     SPLIT_K: tl.constexpr,
     GROUP_M: tl.constexpr,
     XCD_SWIZZLE: tl.constexpr,
+    SWIZZLE_MX_VALUE: tl.constexpr,
 ):
     # unpack and swizzle program ids
     pid_emnk = tile_id
@@ -116,6 +117,8 @@ def _load_tile_attrs(
         K_W = K * (PACKED_BLOCK_K_W // BLOCK_K)
     else:
         K_W = K // (BLOCK_K // PACKED_BLOCK_K_W)
+    if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
+        K_W = tl.cdiv(K_W, 128) * 128
     k_tiles = tl.cdiv(K - off_k_x, BLOCK_K * SPLIT_K)
     if ExptData is None:
         tl.static_assert(M is not None)
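
The added clause rounds the packed weight extent K_W up to the next 128-wide boundary via ceiling division. A minimal host-side sketch of the same round-up arithmetic (plain Python; the cdiv helper here is a stand-in for tl.cdiv, not part of the kernel code):

def cdiv(a, b):
    # ceiling division: smallest q such that q * b >= a
    return (a + b - 1) // b

def round_up(x, multiple):
    # mirrors K_W = tl.cdiv(K_W, 128) * 128
    return cdiv(x, multiple) * multiple

assert round_up(96, 128) == 128   # a partial tile is rounded up
assert round_up(256, 128) == 256  # exact multiples are unchanged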

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 5 additions & 2 deletions
@@ -222,7 +222,7 @@ def _matmul_ogs(
         M, K, ExptData, ExptHist, ExptOffs, ExptTileOffs,
         EXPT_IS_INNER, X_IS_PADDED, W_IS_PADDED,
         BLOCK_M, BLOCK_K, PACKED_BLOCK_K_W, SPLIT_K,
-        GROUP_M, XCD_SWIZZLE)
+        GROUP_M, XCD_SWIZZLE, SWIZZLE_MX_VALUE)
 
     # For split-k, advance to the output k slice
     if SPLIT_K > 1:
@@ -290,7 +290,10 @@ def _matmul_ogs(
 
     # B pointers
     offs_w_n = pid_n * PACKED_BLOCK_N_W + tl.arange(0, PACKED_BLOCK_N_W)
-    offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % (N // W_N_DIVISOR), PACKED_BLOCK_N_W), PACKED_BLOCK_N_W)
+    N_W = N
+    if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
+        N_W = tl.cdiv(N_W, 64) * 64
+    offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % (N_W // W_N_DIVISOR), PACKED_BLOCK_N_W), PACKED_BLOCK_N_W)
 
     if is_x_microscaled:
         XMxScale += start_z.to(index_type) * stride_x_mx_z
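
With HOPPER_VALUE swizzling the weight tensor is stored with its N extent padded to a multiple of 64, so the wrap-around modulo on offs_w_n must use the padded width N_W rather than the logical N. A rough standalone illustration (plain Python; W_N_DIVISOR is taken as 1 purely for simplicity):

PACKED_BLOCK_N_W = 64
N = 80                            # logical number of weight columns
N_W = ((N + 63) // 64) * 64       # padded width under HOPPER_VALUE -> 128
pid_n = 1
offs_w_n = [pid_n * PACKED_BLOCK_N_W + i for i in range(PACKED_BLOCK_N_W)]
# Wrapping at the logical N would alias columns 80..127 back onto 0..47;
# wrapping at the padded N_W keeps them inside the padded weight buffer.
assert [o % N_W for o in offs_w_n] == offs_w_n
assert [o % N for o in offs_w_n] != offs_w_n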

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 2 additions & 2 deletions
@@ -217,7 +217,7 @@ def _p_matmul_ogs(
         M, K, ExptData, ExptHist, ExptOffs, ExptTileOffs,
         EXPT_IS_INNER, X_IS_PADDED, W_IS_PADDED,
         BLOCK_M, BLOCK_K, PACKED_BLOCK_K_W, SPLIT_K,
-        GROUP_M, XCD_SWIZZLE)
+        GROUP_M, XCD_SWIZZLE, SWIZZLE_MX_VALUE)
     off_n = BLOCK_N * pid_n
 
     # Base pointers and offsets.
@@ -347,7 +347,7 @@ def _p_matmul_ogs(
             M, K, ExptData, ExptHist, ExptOffs, ExptTileOffs,
             EXPT_IS_INNER, X_IS_PADDED, W_IS_PADDED,
             BLOCK_M, BLOCK_K, PACKED_BLOCK_K_W, SPLIT_K,
-            GROUP_M, XCD_SWIZZLE)
+            GROUP_M, XCD_SWIZZLE, SWIZZLE_MX_VALUE)
         off_n1 = pid_n1 * BLOCK_N
     else:
         tile_id1, expt_id1, start_z1, start_m1, eM1 = tile_id, expt_id, start_z_out, start_m, eM

python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py

Lines changed: 8 additions & 0 deletions
@@ -120,6 +120,14 @@ def swizzle_data(self, data):
         batch = data.ndim - 2
         assert batch >= 0
         assert self.mma_version in (2, 3)
+        # Pre-pad both matrix dims to multiples of 64
+        *_, M_in, K_in = data.shape
+        SWIZZLE_ALIGN_M = 64
+        SWIZZLE_ALIGN_K = 64
+        pad_m = (SWIZZLE_ALIGN_M - (M_in % SWIZZLE_ALIGN_M)) % SWIZZLE_ALIGN_M
+        pad_k = (SWIZZLE_ALIGN_K - (K_in % SWIZZLE_ALIGN_K)) % SWIZZLE_ALIGN_K
+        data = torch.nn.functional.pad(data, (0, pad_k, 0, pad_m))
+
         data = self._maybe_mT(data)
         init_shape = data.shape
 
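
The new code pads the trailing two dimensions up to multiples of 64 before swizzling; (ALIGN - (x % ALIGN)) % ALIGN evaluates to 0 when x is already aligned. A minimal standalone sketch of the same step (assumes PyTorch; the helper name and shapes are illustrative, not taken from the library):

import torch
import torch.nn.functional as F

def pad_to_multiple_of_64(data: torch.Tensor) -> torch.Tensor:
    *_, m, k = data.shape
    pad_m = (64 - (m % 64)) % 64   # rows to append (0 if already aligned)
    pad_k = (64 - (k % 64)) % 64   # columns to append
    # F.pad pads the last dim first: (left_k, right_k, left_m, right_m)
    return F.pad(data, (0, pad_k, 0, pad_m))

x = torch.zeros(2, 100, 70)        # batch of 2 matrices of shape 100x70
assert pad_to_multiple_of_64(x).shape == (2, 128, 128)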
