
Commit b010cf1

[KERNELS] fix mxfp4 constraints and split k in persistent matmul (#7119)
- Strides need to be a multiple of 32 bytes for the fp4 TMA.
- The inner dim needs to be a multiple of 128 for the fp4 TMA; the previous code only checked `w.shape[-1]`, but this is the wrong axis when SWAP_XW mode is used.
- Split-k offsets and masks were computed incorrectly.
- Update test shapes so they actually exercise the persistent TMA codepath.
1 parent 8a6dfa5 commit b010cf1
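For context on the numbers above: mxfp4 (e2m1) weights are stored two values per uint8 byte, so the 32-byte stride and 128-element inner-dim requirements translate into small packed sizes. A back-of-the-envelope check (plain Python, illustrative only, not code from this commit) of why the test K dimension plausibly moves from 800 to 832:

```python
# mxfp4 packs two fp4 (e2m1) values per uint8 byte.
FP4_PER_BYTE = 2

# The 32-byte TMA stride requirement expressed in fp4 elements.
print(32 * FP4_PER_BYTE)  # 64

# Packed K size in bytes for the old and new test shapes:
# 800 fp4 values -> 400 bytes (400 % 32 == 16, not aligned),
# 832 fp4 values -> 416 bytes (416 % 32 == 0, aligned).
print(800 // FP4_PER_BYTE % 32, 832 // FP4_PER_BYTE % 32)  # 16 0
```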

File tree: 3 files changed (+35, -24 lines)


python/triton_kernels/tests/test_matmul.py

Lines changed: 9 additions & 8 deletions
@@ -193,16 +193,17 @@ class Case:
     Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4),
     Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
     Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
-    Case(1000, 704, 800, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
-    Case(1000, 704, 800, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1),
-    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9),
-    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
-    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2),
-    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
+    Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
+    Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
+    Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1),
+    Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9),
+    Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+    Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2),
+    Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
     Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4),
     Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
-    Case(300, 400, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4),
-    Case(300, 400, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
+    Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4),
+    Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
     Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4),
     Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
     # AMD

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 4 additions & 6 deletions
@@ -250,6 +250,8 @@ def mx_can_use_tma(mx_ctx: MicroscalingCtx):

 def can_use_persistent_tma(x, w, gather_indx, precision_config):
     mx_ctx = precision_config.mx_ctx
+    is_mxfp4 = mx_ctx.weight_scale is not None and w.dtype == torch.uint8
+    weight_stride_req = 32 if is_mxfp4 else 16
     return (
         # TMA requires CUDA 9.0, last dim contiguous, and multiple of 16-byte strides otherwise.
         target_info.cuda_capability_geq(9, 0) and
@@ -258,14 +260,10 @@ def can_use_persistent_tma(x, w, gather_indx, precision_config):
            x.stride(1) * x.element_size() % 16 == 0 and x.stride(2) == 1
        ) and (
            # Check W is either transposed or non-transposed, and with required stride.
-           (w.stride(1) * w.element_size() % 16 == 0 and w.stride(2) == 1) or
-           (w.stride(2) * w.element_size() % 16 == 0 and w.stride(1) == 1)
+           (w.stride(1) * w.element_size() % weight_stride_req == 0 and w.stride(2) == 1) or
+           (w.stride(2) * w.element_size() % weight_stride_req == 0 and w.stride(1) == 1)
        ) and (
            mx_ctx.weight_scale is None or mx_can_use_tma(mx_ctx)
-       ) and (
-           # MFXP4 tma requires 128 elements on the inner dim.
-           # MFXP4 is represented as packed uint8.
-           w.dtype != torch.uint8 or w.shape[-1] % 128 == 0
        )
        # compiler crash ?
        and (x.dtype.itemsize <= 1 or w.dtype != torch.uint8)
python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 22 additions & 10 deletions
@@ -35,20 +35,24 @@ def _update_tensor_desc(desc, ptr, shape=None):
     )

 @triton.jit
-def _make_tensor_desc(ptr, shape, strides, block_shape, transpose: tl.constexpr = False):
+def _multiple_of(a, b):
+    return tl.cdiv(a, b) * b
+
+@triton.jit
+def _make_tensor_desc(ptr, shape, strides, block_shape, transpose: tl.constexpr = False, pad_inner_shape: tl.constexpr = 1):
     tl.static_assert(len(shape) == len(strides))
     tl.static_assert(len(strides) == len(block_shape))
     if transpose:
         return tl.make_tensor_descriptor(
             ptr,
-            shape=shape[:-2] + [shape[-1], shape[-2]],
+            shape=shape[:-2] + [shape[-1], _multiple_of(shape[-2], pad_inner_shape)],
             strides=strides[:-2] + [strides[-1], tl.constexpr(1)],
             block_shape=block_shape[:-2] + [block_shape[-1], block_shape[-2]],
         )
     else:
         return tl.make_tensor_descriptor(
             ptr,
-            shape=shape,
+            shape=shape[:-1] + [_multiple_of(shape[-1], pad_inner_shape)],
             strides=strides[:-1] + [tl.constexpr(1)],
             block_shape=block_shape,
         )
@@ -235,12 +239,20 @@ def _p_matmul_ogs(
         block_shape=[BLOCK_M, BLOCK_K]
     )

+    # Pad the inner shape to 128 for mxfp4 weights; TMA requires this when the compiler uses CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B.
+    # This technically makes the shape masking incorrect, but it's fine because:
+    # - When the N dim is padded, the scales will be masked to 0.
+    # - When the K dim is padded, the activations we perform tl.dot with will be masked to 0.
+    #   Note: the scales can't be relied on for zeroing in this case, because they apply to groups
+    #   of 32 elements in the K dimension.
+    w_pad_inner_shape = 128 if is_microscaled_format and W.dtype.element_ty == tl.uint8 else 1
     w_desc = _make_tensor_desc(W,
         shape=[N_EXPTS_TOT if ExptData is not None else batch_size,
                (K + W_PACK_DIVISOR - 1) // W_PACK_DIVISOR, N],
         strides=[stride_w_e, stride_w_k, stride_w_n],
         block_shape=[1, PACKED_BLOCK_K_W, BLOCK_N],
-        transpose=W_TRANSPOSE)
+        transpose=W_TRANSPOSE,
+        pad_inner_shape=w_pad_inner_shape)

     if is_microscaled_format:
         PackedK = (K + MX_PACK_DIVISOR - 1) // MX_PACK_DIVISOR
@@ -320,7 +332,7 @@ def _p_matmul_ogs(

     if SPLIT_K > 1:
         offs_mx_k += MX_SCALE_BLOCK_K * pid_k
-        offs_mx_inner += PACKED_MX_BLOCK * pid_k
+        offs_mx_inner += (MX_SCALE_BLOCK_K // 4) * pid_k * stride_mx_k

     if X_USE_LOAD_TMA:
         if ExptData is None:
@@ -357,13 +369,13 @@ def _p_matmul_ogs(
         else:
             XPtrs = XBase + offs_x_m + offs_x_k
             XBase += BLOCK_K * SPLIT_K * stride_x_k
+            mask_k = tl.arange(0, BLOCK_K) < K - off_k
             if EVEN_K:
                 if SPLIT_K > 1:
-                    x = tl.load(XPtrs, mask=off_k < K, other=0.0)
+                    x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
                 else:
                     x = tl.load(XPtrs)
             else:
-                mask_k = tl.arange(0, BLOCK_K) < K - off_k
                 x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)

         w = _load_tensor_desc(w_desc, [expt_id, off_k_w, off_n], transpose=W_TRANSPOSE)
@@ -381,17 +393,17 @@ def _p_matmul_ogs(
                 w_scales = unswizzle_mx_scale_bw(tl.load(MxPtrs))
             else:
                 MxPtrs = MxScale + expt_id.to(index_type) * stride_mx_e + offs_mx_k.to(index_type)[None, :] * stride_mx_k + offs_w_n.to(index_type)[:, None] * stride_mx_n + ki * MX_SCALE_BLOCK_K * SPLIT_K * stride_mx_k
+                mask_k = offs_mx_k < tl.cdiv(K - off_k, MX_PACK_DIVISOR)
                 if EVEN_K:
                     if SPLIT_K > 1:
-                        w_scales = tl.load(MxPtrs, mask=off_k < K, other=0.0)
+                        w_scales = tl.load(MxPtrs, mask=mask_k[None, :], other=0.0)
                     else:
                         w_scales = tl.load(MxPtrs)
                 else:
-                    mask_k = offs_mx_k < tl.cdiv(K - off_k, MX_PACK_DIVISOR)
                     w_scales = tl.load(MxPtrs, mask=mask_k[None, :], other=0.0)

         elif SWIZZLE_MX_SCALE == "BLACKWELL":
-            w_scales = mx_desc.load([expt_id, off_n // 128, ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K), 0, 0])
+            w_scales = mx_desc.load([expt_id, off_n // 128, pid_k * MX_SCALE_BLOCK_K // 4 + ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K), 0, 0])
             w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * 32 * 4 * 4))
             w_scales = unswizzle_mx_scale_bw(w_scales)
         else:
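As a sanity check on the padding helper introduced above, the same round-up can be reproduced in plain Python with the shapes from the updated tests (illustrative only; mirrors `tl.cdiv(a, b) * b` from the diff):

```python
def multiple_of(a, b):
    # Plain-Python equivalent of the Triton helper above: ceil-divide, then scale back up.
    return -(-a // b) * b

# Non-transposed weights: the inner descriptor dim is N, so N = 704 pads to 768.
print(multiple_of(704, 128))       # 768

# Transposed weights: the inner dim is the packed K, so 832 // 2 = 416 pads to 512.
print(multiple_of(832 // 2, 128))  # 512
```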
