Commit f81125b

[mxfp] handle w_scale w/o swizzle correctly (#8652)
In practice, we can't support w_scale with column-wise strided layout, since we will divide the reduction dim by 32 then it needs to be a multiple of 16 for TMA. So, we disable TMA (and persistent kernel) for this case. Added a test case for this. Before this PR the test case led to ``` E triton.compiler.errors.CompilationError: at 227:26: E w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * w_scales.shape[-2] * w_scales.shape[-1])) E w_scales = unswizzle_mx_scale_bw(w_scales) E else: E w_scales = WMxScale.load([expt_id, off_k_mx, off_n]) E w_scales = tl.reshape(w_scales, *w_scales.shape[1:]).T E E # --- update accumulator --- E if is_w_microscaled: E if SWAP_XW: E acc = tl.dot_scaled(w.T, w_scales, w_format, x.T, x_scales, x_format, acc=acc, fast_math=True) E else: E acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True) E ^ E rhs_scale must be a tensor of shape [256, 4]. Got ['4', '256'] ``` The way ``make_dense_tma`` was checking if it was called for scale was also ambiguous. Previously, it assumed for ``StridedLayout`` it's not scale which is wrong. # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent e93fc76 commit f81125b

File tree

4 files changed: +29 -19 lines changed

python/triton_kernels/tests/test_matmul.py

Lines changed: 1 addition & 0 deletions
@@ -233,6 +233,7 @@ class Case:
     Case(16, 16, 1000, "batched", "float8_e5m2", "float8_e5m2", 5, 1, split_k=None),
     Case(16, 16, 2048, "batched", "float8_e5m2", "float8_e5m2", 6, 1, split_k=5),
     # mx types:
+    Case(1, 1024, 1024, "plain", "bfloat16", "mxfloat8_e4m3fn", 1, 1),
     Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1),
     Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
     Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True, epilogue_subtile=4),

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 17 additions & 6 deletions
@@ -416,6 +416,12 @@ def matmul_ogs(x, w, bias,
         # unaligned access.
         (inner_routing_data is None or w.stride(-1) == 1 or inner_routing_data.w_is_padded)
     )
+    if w_scale is not None and isinstance(w_scale.storage.layout, StridedLayout) and w_scale.storage.data.stride()[-1] != 1:
+        # In this case, we need to transpose w_scale, so the reduction dim
+        # becomes the last dim and is divided by 32. For the result to be a
+        # multiple of 16, as TMA requires, block_k would have to be a multiple
+        # of 512, which is too big.
+        can_use_tma = False
     has_gather_tma = has_gather and target_info.has_tma_gather()
     # hopper w/ mxfp4 doesn't support TMA
     can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
@@ -526,14 +532,23 @@ def matmul_ogs(x, w, bias,
     w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data
     # create tma descriptor for w_scale
     w_scale_has_tma = opt_flags.is_persistent and w_scale is not None
+    # When stride(-2) == stride(-1) == 1, it's ambiguous whether W is transposed
+    # (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose
+    # is True the fast code path, stride(-2) == 1 takes precedence, e.g., vs.
+    # w_transpose = w_storage.data.stride()[-1] != 1
     w_transpose = w_storage.data.stride()[-2] == 1
     if w_scale_has_tma:
         w_scale_storage = w_scale.storage
-        w_scale_tma_block_size = [opt_flags.block_n, opt_flags.block_k] if w_transpose else [opt_flags.block_k, opt_flags.block_n]
+        scale_block_k = opt_flags.block_k // int(MXFP_BLOCK_SIZE)
+        # cancel out the transpose done inside make_tma since
+        # BlackwellMXScaleLayout.swizzle_block_shape expects block_shape[1] is
+        # the reduction dimension.
+        w_scale_tma_block_size = [opt_flags.block_n, scale_block_k] if w_transpose and w_scale.storage.layout.name == "BLACKWELL_SCALE" else [scale_block_k, opt_flags.block_n]
         if isinstance(w_scale.storage.layout, StridedLayout):
+            assert w_scale_storage.data.stride()[-1] == 1, "w_scale should be contiguous with StridedLayout"
             w_scale_storage = _canonicalize_storage(w_scale.storage, 3, None)
             w_scale_tma_block_size = [1] + w_scale_tma_block_size
-        w_scale_tensor_or_tma = w_scale_storage.make_tma(w_scale_tma_block_size, "dense")
+        w_scale_tensor_or_tma = w_scale_storage.make_tma(w_scale_tma_block_size, "dense", is_scale=True)
     else:
         w_scale_tensor_or_tma = w_scale
     # canonicalize strides
@@ -546,10 +561,6 @@ def matmul_ogs(x, w, bias,
     out_matmul_scale_strides = (0, ) * (4 - len(out_matmul_scale_strides)) + out_matmul_scale_strides
     # launch kernel
     kernels = specializations.get(epilogue=epilogue.specs, activation=matmul_fused_activation.specs)
-    # When stride(-2) == stride(-1) == 1, it's ambiguous whether W is transposed
-    # (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose
-    # is True the fast code path, stride(-2) == 1 takes precedence, e.g., vs.
-    # w_transpose = w_storage.data.stride()[-1] != 1
     if gather_indx is not None:
         gather_src_indx = torch.div(gather_indx.src_indx, routing_data.n_expts_act, rounding_mode='trunc')
     fused_comm_kwargs = {
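As a reading aid for the hunks above, here is a hedged, standalone sketch of how the w_scale TMA block shape is now chosen. `opt_flags`, the layout-name check, and the helper name are simplified stand-ins for the real objects in `matmul_ogs.py`, not the library API.

```python
# Simplified stand-in for the block-shape selection in matmul_ogs.py (not the real API).
MXFP_BLOCK_SIZE = 32

def pick_w_scale_tma_block(block_k, block_n, w_transpose, layout_name, is_strided_layout):
    # The reduction dim of the scale tensor is block_k // 32, not block_k.
    scale_block_k = block_k // MXFP_BLOCK_SIZE
    if w_transpose and layout_name == "BLACKWELL_SCALE":
        # Pre-swap so that, after make_tma's internal transpose, swizzle_block_shape
        # still sees the reduction dim in block_shape[1].
        block = [block_n, scale_block_k]
    else:
        block = [scale_block_k, block_n]
    if is_strided_layout:
        block = [1] + block   # StridedLayout storage is canonicalized to 3 dims
    return block

print(pick_w_scale_tma_block(128, 256, True, "BLACKWELL_SCALE", False))  # [256, 4]
print(pick_w_scale_tma_block(128, 256, False, None, True))               # [1, 4, 256]
```

The `[256, 4]` shape in the first example matches the shape expected by `tl.dot_scaled` in the error trace from the PR description.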

python/triton_kernels/triton_kernels/tensor.py

Lines changed: 9 additions & 9 deletions
@@ -47,28 +47,28 @@ def is_tma_compliant(self):
         compliant = [strides[i] * bitwidth % 128 == 0 for i in range(ndim) if i != major_dim]
         return all(compliant)

-    def make_dense_tma(self, block_shape):
+    def make_dense_tma(self, block_shape, is_scale):
         strides = list(self.data.stride())
         shape = list(self.data.shape)
         transpose = strides[-1] != 1
         if transpose:
+            # Need to transpose since the tensor descriptor expects strides, except for the last dimension's, to be 16-byte aligned
+            # https://github.com/triton-lang/triton/blob/e5e0081db3335e7755e2c67c784cb1c92769812f/python/triton/tools/tensor_descriptor.py#L26
             block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
             shape = shape[:-2] + [shape[-1], shape[-2]]
             strides = strides[:-2] + [strides[-1], strides[-2]]
-        if self.data.dtype == torch.uint8 and (self.layout.name is None or "_SCALE" not in self.layout.name):
+        if self.data.dtype == torch.uint8 and not is_scale:
             indx = strides.index(1)
             block_shape[indx] = block_shape[indx] // 2
-        if isinstance(self.layout, BlackwellMXValueLayout):
-            if shape[-1] % 128 != 0:
-                raise ValueError(
-                    "inner shape need to be multiple of 128 for mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs."
-                )
+        if isinstance(self.layout, BlackwellMXValueLayout) and shape[-1] % 128 != 0:
+            raise ValueError(
+                "inner shape need to be multiple of 128 for mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs.")
         block_shape = self.layout.swizzle_block_shape(block_shape)
         return TensorDescriptor(self.data, shape, strides, block_shape)

-    def make_tma(self, block_shape, mode):
+    def make_tma(self, block_shape, mode, is_scale=False):
         if mode in ["dense", "gather", "scatter"]:
-            return self.make_dense_tma(block_shape)
+            return self.make_dense_tma(block_shape, is_scale)
         assert mode == "ragged"
         ragged_dim = len(self.data.shape) - 2
         return create_ragged_descriptor(self.data, block_shape, ragged_dim=ragged_dim)
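To make the `is_scale` change above concrete, a small, hedged illustration (standalone; `layout_name` stands in for `self.layout.name`): uint8 storage is used both for packed fp4 values, whose contiguous block dim is halved, and for MX scales, which must not be halved. The old name-based heuristic misclassified scales held in a plain `StridedLayout`, whose layout name is None.

```python
import torch

# Old heuristic: infer "not a scale" from the layout name (None for StridedLayout).
def old_halves_block(dtype, layout_name):
    return dtype == torch.uint8 and (layout_name is None or "_SCALE" not in layout_name)

# New behavior: the caller states explicitly whether this tensor holds scales.
def new_halves_block(dtype, is_scale):
    return dtype == torch.uint8 and not is_scale

# A uint8 w_scale stored with a plain StridedLayout:
print(old_halves_block(torch.uint8, None))           # True  -> block dim wrongly halved
print(new_halves_block(torch.uint8, is_scale=True))  # False -> block dim left intact
```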

python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py

Lines changed: 2 additions & 4 deletions
@@ -34,7 +34,7 @@ def swizzle_data(self, data):
         data = data.reshape(self.B, self.N_pad // self.ALIGN_N, self.ALIGN_N // 32, 32, self.K_pad // self.SWIZZLE_K,
                             self.SWIZZLE_K)
         data = data.transpose(2, 4).contiguous()
-        data = data.view(1, self.B * self.N_pad // 128, self.K_pad // 4, 2, 256)
+        data = data.view(1, self.B * self.N_pad // 128, self.K_pad // self.SWIZZLE_K, 2, 256)
         return data

     def unswizzle_data(self, data):
@@ -46,10 +46,8 @@ def unswizzle_data(self, data):
         return data[..., :self.K, :self.N]

     def swizzle_block_shape(self, block_shape):
-        MX_PACK_DIVISOR = 32
-        MX_SCALE_BLOCK_K = block_shape[1] // MX_PACK_DIVISOR
         assert block_shape[0] >= 128, f"{block_shape[0]=} must be >= 128"
-        return [1, block_shape[0] // 128, MX_SCALE_BLOCK_K // 4, 2, 256]
+        return [1, block_shape[0] // 128, block_shape[1] // 4, 2, 256]


     @triton.jit
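A brief, hedged before/after check of the `swizzle_block_shape` change (assumed block sizes only): the method no longer divides by the 32-element MX group itself; callers now pass `block_k // 32` in `block_shape[1]`, as done in `matmul_ogs.py` above, and both conventions produce the same swizzled shape.

```python
# Before: received [block_n, block_k] and divided by the MX group size internally.
def old_swizzle_block_shape(block_shape):
    MX_PACK_DIVISOR = 32
    mx_scale_block_k = block_shape[1] // MX_PACK_DIVISOR
    return [1, block_shape[0] // 128, mx_scale_block_k // 4, 2, 256]

# After: receives [block_n, block_k // 32], so only the // 4 split remains.
def new_swizzle_block_shape(block_shape):
    return [1, block_shape[0] // 128, block_shape[1] // 4, 2, 256]

block_n, block_k = 256, 512
assert old_swizzle_block_shape([block_n, block_k]) == new_swizzle_block_shape([block_n, block_k // 32])
print(new_swizzle_block_shape([block_n, block_k // 32]))  # [1, 2, 4, 2, 256]
```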
