
Commit 981b0bb

[KERNELS] improved twiddling/swizzling for H100 simulated mxfp4 (h/t @lezcano) (#7587)
1 parent 63f3432 commit 981b0bb

12 files changed: +296 additions, −148 deletions


python/triton_kernels/bench/bench_mlp.py

Lines changed: 11 additions & 14 deletions
@@ -10,11 +10,11 @@
 from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
 from triton_kernels.numerics import InFlexData
 from triton_kernels.routing import routing
-from triton_kernels.target_info import is_hip, get_cdna_version, is_cuda
+from triton_kernels.target_info import is_hip, get_cdna_version
 from triton_kernels.tensor import convert_layout
-from triton_kernels.tensor_details.layout import StridedLayout, BlackwellMXScaleLayout, HopperMXScaleLayout, HopperMXValueLayout
 from triton_kernels.tensor import wrap_torch_tensor, FP4
 from dataclasses import dataclass
+from triton_kernels.tensor_details import layout
 
 if torch.cuda.is_available() and not is_hip():
     from triton._C.libtriton import nvidia
@@ -36,8 +36,8 @@ def quantize(w, dtype, dev, **opt):
     else:
         assert dtype == "mx4", f"{dtype=}"
         w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
-        w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"])
-        w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"])
+        w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"], **opt["value_layout_opts"])
+        w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"], **opt["scale_layout_opts"])
         return w, InFlexData(), w_scale
 
 
@@ -101,16 +101,13 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
     optg = dict()
     opt1 = dict()
     opt2 = dict()
-    if w_dtype == "mx4":
-        value_layout = StridedLayout
-        scale_layout = StridedLayout
-        if is_cuda():
-            if torch.cuda.get_device_capability()[0] == 9:
-                value_layout = HopperMXValueLayout
-                scale_layout = HopperMXScaleLayout
-            if torch.cuda.get_device_capability()[0] == 10:
-                scale_layout = BlackwellMXScaleLayout
-        opt1 = {"value_layout": value_layout, "scale_layout": scale_layout}
+    if w_dtype == "mx4" and not is_hip():
+        num_warps = 4 if batch <= 512 else 8
+        value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
+        scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
+            mx_axis=1, num_warps=num_warps)
+        opt1 = {"value_layout": value_layout, "value_layout_opts": value_layout_opts, \
+                "scale_layout": scale_layout, "scale_layout_opts": scale_layout_opts}
     opt2 = deepcopy(opt1)
     wg, wg_flex, wg_scale = quantize(wg, "bf16", dev, **optg)
     w1, w1_flex, w1_scale = quantize(w1, w_dtype, dev, **opt1)
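Taken together, the reworked call sites follow one pattern: downcast the weights to mxfp4, wrap the raw value and scale tensors, then let a per-architecture factory pick the layout class and its constructor options. Below is a minimal sketch of that flow using only APIs touched in this commit; the helper name to_swizzled_mxfp4 and the w_q / w_scale_q inputs are illustrative (assumed to be the uint8 outputs of downcast_to_mxfp, whose import path is not shown in this diff).

import torch
from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4
from triton_kernels.tensor_details import layout

def to_swizzled_mxfp4(w_q: torch.Tensor, w_scale_q: torch.Tensor, mx_axis: int = 1, num_warps: int = 8):
    # Pick the default (layout class, constructor kwargs) pair for this GPU:
    # Blackwell keeps its native scale swizzle, Hopper gets the simulated-mxfp4
    # value/scale swizzles, everything else falls back to StridedLayout.
    value_layout, value_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=mx_axis)
    scale_layout, scale_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
        mx_axis=mx_axis, num_warps=num_warps)
    # convert_layout now forwards the kwargs to the layout constructor (see tensor.py below).
    w = convert_layout(wrap_torch_tensor(w_q, dtype=FP4), value_layout, **value_opts)
    w_scale = convert_layout(wrap_torch_tensor(w_scale_q), scale_layout, **scale_opts)
    return w, w_scale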

python/triton_kernels/tests/test_matmul.py

Lines changed: 16 additions & 17 deletions
@@ -328,25 +328,24 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)
 
     if is_mixed_input:
-        capability_major = torch.cuda.get_device_capability()[0]
-        w_layout = layout.StridedLayout
-        w_scale_layout = layout.StridedLayout
+        mx_axis = w_tri.ndim - 2
+        # compute layouts
+        w_layout, w_layout_opts = layout.StridedLayout, dict()
+        w_scale_layout, w_scale_layout_opts = layout.StridedLayout, dict()
         if hbm_swizzling and "float4" in weight_dtype_str:
-            # weight layout
-            w_layouts = {9: layout.HopperMXValueLayout}
-            w_layout = w_layouts.get(capability_major, layout.StridedLayout)
-            # weight scale layout
-            w_scales_layouts = {9: layout.HopperMXScaleLayout, 10: layout.BlackwellMXScaleLayout}
-            w_scale_layout = w_scales_layouts.get(capability_major, layout.StridedLayout)
-        w_tri, mx_scales_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=-2)
-        w_ref = upcast_from_mxfp(w_tri, mx_scales_tri, torch.bfloat16, axis=-2)
+            w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=mx_axis)
+            w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
+                mx_axis=mx_axis, num_warps=8)
+        # downcast to mxfp
+        w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+        w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
         w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
-        w_tri = convert_layout(wrap_torch_tensor(w_tri, w_tri_dtype), w_layout)
-        mx_scales_tri = convert_layout(wrap_torch_tensor(mx_scales_tri), w_scale_layout)
-        precision_opt.weight_scale = mx_scales_tri
-
-        # if not is_persistent and precision_opt.weight_scale is not None:
-        #     pytest.skip("non-persistent not supported with mxfp")
+        w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
+        w_scale_tri = wrap_torch_tensor(w_scale_tri)
+        # convert layouts
+        w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
+        w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+        precision_opt.weight_scale = w_scale_tri
 
     if test_launch_metadata:

Lines changed: 1 addition & 60 deletions
@@ -1,60 +1 @@
-import torch
-import pytest
-import math
-from triton_kernels.testing import assert_equal
-from triton_kernels.tensor_details.layout import BlackwellMXScaleLayout, HopperMXScaleLayout, HopperMXValueLayout
-
-
-@pytest.mark.parametrize(
-    "shape",
-    [
-        (3, 4096, 1024),
-        (10, 254, 60),
-        (1, 320, 160),
-        (2, 16, 512),
-        (3, 2, 36),
-    ],
-)
-def test_mxfp_swizzle(shape: tuple[int, ...]):
-    """
-    Test that unswizzle is the inverse of swizzle, after removing padding.
-    """
-    x = torch.randn(shape, device="cuda")
-    layout = BlackwellMXScaleLayout(shape)
-    assert_equal(x, layout.unswizzle_data(layout.swizzle_data(x)))
-
-
-@pytest.mark.parametrize("shape", [(16, 32), (16, 64), (32, 32), (32, 64), (64, 128), (128, 128)])
-@pytest.mark.parametrize("trans", [False, True])
-@pytest.mark.parametrize("op_idx", [0, 1])
-@pytest.mark.parametrize("mma_version", [2, 3])
-def test_swizzle_mxfp4_value(shape, trans, op_idx, mma_version):
-    x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
-    if trans:
-        x = x.mT
-    k_dim = 1 - op_idx
-    if x.shape[k_dim] < 32:
-        pytest.skip("Not enough elements along K")
-    layout = HopperMXValueLayout(x.shape, op_idx, mma_version)
-    res = layout.unswizzle_data(layout.swizzle_data(x))
-    assert (res == x).all()
-
-
-@pytest.mark.parametrize("num_warps", [4, 8])
-@pytest.mark.parametrize("shape", [(256, 64), (256, 128), (256, 256)])
-def test_swizzle_mxfp4_scale(shape, num_warps):
-    x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
-    layout = HopperMXScaleLayout(x.shape, num_warps=num_warps)
-    res = layout.unswizzle_data(layout.swizzle_data(x))
-    assert (res[:shape[0], :shape[1]] == x).all()
-
-
-def test_unswizzle_mxfp4_value_golden_value():
-    shape = (16, 32)
-    x = torch.arange(math.prod(shape)).view(shape).to(torch.uint8)
-    layout = HopperMXValueLayout(x.shape, op_idx=1, mma_version=3)
-    res = layout.swizzle_data(x)
-    # Thread 0
-    assert res[0, 0:16].tolist() == [0, 0, 4, 4, 8, 8, 12, 12, 16, 16, 20, 20, 24, 24, 28, 28]
-    # Thread 1
-    assert res[0, 16:32].tolist() == [1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21, 25, 25, 29, 29]
+# TODO: add tests for non-layout parts of tensor class
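The deleted file exercised the swizzle/unswizzle round-trip for the Blackwell and Hopper layouts. An equivalent spot check can still be written directly against the reworked layout classes; the sketch below is an assumption-heavy sketch, not part of the commit: it assumes the reworked HopperMXScaleLayout constructor accepts the mx_axis keyword that the new factory helpers pass through convert_layout, and that mx_axis=1 is valid for a 2D scale tensor.

import torch
from triton_kernels.tensor_details.layout import HopperMXScaleLayout

def check_scale_roundtrip(shape=(256, 64), num_warps=8):
    # unswizzle(swizzle(x)) should recover x once any padding is cropped away
    x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
    lyt = HopperMXScaleLayout(x.shape, mx_axis=1, num_warps=num_warps)  # mx_axis kwarg assumed
    res = lyt.unswizzle_data(lyt.swizzle_data(x))
    assert (res[:shape[0], :shape[1]] == x).all()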

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 4 additions & 2 deletions
@@ -482,6 +482,8 @@ def matmul_ogs(x, w, bias,
     w_scale_strides = w_scale.stride() if has_mx and not w_scale_has_tma else (None, None, None)
     if len(w_scale_strides) == 2:
         w_scale_strides = (0, ) + w_scale_strides
+    # if routing_data.expt_hist is not None:
+    #     print(opt_flags)
     # launch kernel
     kernels = get_kernels(epilogue.specs, fused_activation.specs)
     (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
@@ -532,8 +534,8 @@
         **opt_flags.target_kernel_kwargs)
     # post-processing
     out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_token_offs_raw,
-                                         num_indx, precision_config, routing_data,
-                                         postprocessing_features, memory, fused_postprocess_activation, epilogue)
+                                        num_indx, precision_config, routing_data,
+                                        postprocessing_features, memory, fused_postprocess_activation, epilogue)
     # remove split-k
     out = out.squeeze(0)
     if not is_input_batched:

python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def convert_dtype(dtype):
         # suffix = "" if not mode else "_o" + (''.join(mode))
         # if base_name.startswith("_p"):
         #     suffix += "_ptma"
-        return f"{base_name}_{layouts}_{dtypes}_{blocks}"
+        return f"cutlass_{base_name}_{layouts}_{dtypes}_{blocks}"
 
     return matmul_repr
 

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 14 additions & 8 deletions
@@ -2,7 +2,7 @@
 import triton.language as tl
 from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
 from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
-from triton_kernels.tensor_details.layout_details.hopper_value import unswizzle_mxfp4_value_hopper
+from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
 from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
 from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
@@ -250,18 +250,24 @@ def _matmul_ogs(
                 w_scales = unswizzle_mx_scale_bw(tl.load(MxScalePtrs))
             elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
                 # Handshake with the swizzling code
-                tl.static_assert(tl.extra.cuda.num_warps() == 8, "Only 8 warps are supported for Hopper swizzling. Got %d" % tl.extra.cuda.num_warps())
-                w_scales = unswizzle_mxfp4_scale_hopper(tl.load(MxScalePtrs), num_warps=8)
+                num_warps: tl.constexpr = tl.extra.cuda.num_warps()
+                w_scales = unswizzle_mxfp4_scale_hopper(tl.load(MxScalePtrs), mx_axis=1, num_warps=num_warps)
             else:
                 w_scales = tl.load(MxScalePtrs, mask=mask_k_scale[None, :], other=0.0)
 
             if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
                 # Handshake with the swizzling code
-                w = unswizzle_mxfp4_value_hopper(w, op_idx=1, mma_version=3)
-                mma_version: tl.constexpr = 3 if w.shape[1] >= 64 else 2
-                tl.static_assert(mma_version == 3, "Only mma_version 3 is supported for Hopper swizzling")
-
-            acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, mx_format, acc=acc, fast_math=True)
+                tl.static_assert(x_format == "bf16")
+                tl.static_assert(mx_format == "e2m1")
+                w = mxfp4_to_bf16_triton(w.trans(), w_scales, 1)
+                tl.static_assert(w.dtype == tl.bfloat16)
+                acc = acc.trans()
+                x = x.trans()
+                # w = w.trans()
+                acc = tl.dot(w, x, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
+                acc = acc.trans()
+            else:
+                acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, mx_format, acc=acc, fast_math=True)
             if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
                 MxScalePtrs += (MX_SCALE_BLOCK_K // 4 * SPLIT_K) * stride_mx_k
             else:
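The new HOPPER_VALUE path drops tl.dot_scaled: the swizzled fp4 weights are upcast to bf16 in-kernel via mxfp4_to_bf16_triton, and a plain tl.dot is issued with the weight tile as the left operand and both x and acc transposed, transposing the accumulator back afterwards. The transpositions cancel algebraically; below is a small PyTorch check of the identity the rewrite relies on (fp32 tensors stand in for the in-kernel bf16 tiles).

import torch

# Original path: acc + x @ w_upcast.
# New path:      (w_upcast.T @ x.T + acc.T).T, which is the same result.
m, k, n = 64, 128, 32
x = torch.randn(m, k)
w = torch.randn(k, n)        # stand-in for the upcast mxfp4 weight tile
acc = torch.randn(m, n)

ref = acc + x @ w
new = (w.T @ x.T + acc.T).T
assert torch.allclose(ref, new, atol=1e-5)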

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 3 additions & 7 deletions
@@ -3,7 +3,6 @@
 from triton_kernels.target_info import get_cdna_version
 import torch
 from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
-from ..tensor import get_layout
 
 # fmt: off
 
@@ -157,12 +156,10 @@ def make_default_opt_flags_nvidia(
     elif enforce_bitwise_invariance:
         block_m = 128
     else:
-        block_m = max(64, min(triton.next_power_of_2(tokens_per_expt), 128))
+        block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
     # block n
     arch = None
     block_n = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
-    if precision_config.weight_scale is not None and get_layout(precision_config.weight_scale).name == "HOPPER_SCALE":
-        block_n = 256
     # is_persistent
     grid_size = opt_flags_nvidia.compute_grid_size(routing_data, m, n, block_m, block_n)
     n_sms = torch.cuda.get_device_properties(0).multi_processor_count
@@ -177,7 +174,7 @@ def make_default_opt_flags_nvidia(
     if constraints.get("block_k", None) is not None:
         block_k = constraints["block_k"]
     else:
-        block_k = opt_flags_nvidia.compute_block_k(k, is_persistent, lhs_dtype, rhs_dtype, precision_config)
+        block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config)
     # split_k
     if constraints.get("split_k", None) is not None:
         split_k = constraints["split_k"]
@@ -219,8 +216,7 @@ def make_default_opt_flags_nvidia(
     else:
         fused_scatter = can_use_fused_scatter and split_k == 1
     # Handshake with the HBM swizzling
-    hopper_swizzling = precision_config.weight_scale is not None and get_layout(precision_config.weight_scale).name == "HOPPER_SCALE"
-    num_warps = 8 if hopper_swizzling else opt_flags_nvidia.compute_num_warps(block_m, block_n)
+    num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, precision_config)
     ret = OptFlags(
         block_m=block_m,
         block_n=block_n,

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py

Lines changed: 14 additions & 9 deletions
@@ -1,7 +1,8 @@
 import torch
 import triton
 from triton_kernels import target_info
-from triton_kernels.tensor import bitwidth, FP4
+from triton_kernels.tensor import get_layout, bitwidth, FP4
+from triton_kernels.tensor_details.layout import HopperMXScaleLayout
 from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 
 
@@ -16,19 +17,20 @@ def compute_grid_size(routing_data, m, n, block_m, block_n):
 
 def compute_block_n(n: int, arch, precision_config):
     # block_n:
-    block_n = max(16, min(128, triton.next_power_of_2(n)))
-    # On Ampere and Hopper, handshake with swizzle_mxfp4_scale_hopper
-    if precision_config.max_num_imprecise_acc is None and n > 128:
-        block_n = 256
-    return block_n
+    layout = get_layout(precision_config.weight_scale)
+    if isinstance(layout, HopperMXScaleLayout) and layout.num_warps == 4:
+        return 128
+    elif precision_config.max_num_imprecise_acc is None and n > 128:
+        return 256
+    else:
+        return max(16, min(128, triton.next_power_of_2(n)))
 
 
-def compute_block_k(k: int | None, is_persistent: bool, lhs_dtype, rhs_dtype, precision_config):
+def compute_block_k(m: int, k: int | None, is_persistent: bool, lhs_dtype, rhs_dtype, precision_config):
     lhs_width = bitwidth(lhs_dtype)
     rhs_width = bitwidth(rhs_dtype)
     # block_k needs to match the cacheline size (1024 bits)
     block_k = int(1024 // min(lhs_width, rhs_width))
-    # TODO: revisit when Triton is better for H100 + MXFP4
     has_native_mxfp = target_info.cuda_capability_geq(10, 0)
     if rhs_width == 4 and not has_native_mxfp:
         block_k = 128
@@ -52,7 +54,10 @@ def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
     return split_k
 
 
-def compute_num_warps(block_m, block_n):
+def compute_num_warps(block_m, block_n, precision_config):
+    layout = get_layout(precision_config.weight_scale)
+    if isinstance(layout, HopperMXScaleLayout):
+        return layout.num_warps
     return max(block_m * block_n // 4096, 4)
 
 
python/triton_kernels/triton_kernels/tensor.py

Lines changed: 5 additions & 3 deletions
@@ -172,7 +172,9 @@ def sum(self, partials_block_size):
         return sum_bitmatrix_rows(self, out_ret, partials_block_size)
 
 
-def get_layout(tensor: torch.Tensor | Tensor):
+def get_layout(tensor: torch.Tensor | Tensor | None):
+    if tensor is None:
+        return None
     if isinstance(tensor, Tensor):
         return tensor.storage.layout
     return StridedLayout
@@ -186,11 +188,11 @@ def wrap_torch_tensor(torch_tensor, dtype=None):
     return Tensor(Storage(torch_tensor), dtype=dtype, shape=shape)
 
 
-def convert_layout(tensor: Tensor, layout_cls: Type[Layout]):
+def convert_layout(tensor: Tensor, layout_cls: Type[Layout], **layout_kwargs):
     assert isinstance(tensor, Tensor)
     old_storage = tensor.storage
     old_data = old_storage.layout.unswizzle_data(old_storage.data)
-    new_layout = layout_cls(old_data.shape)
+    new_layout = layout_cls(old_data.shape, **layout_kwargs)
     new_data = new_layout.swizzle_data(old_data)
     attrs = {k.name: getattr(tensor, k.name) for k in fields(tensor) if k.name != "storage"}
     return Tensor(Storage(new_data, new_layout), **attrs)

python/triton_kernels/triton_kernels/tensor_details/layout.py

Lines changed: 19 additions & 0 deletions
@@ -3,6 +3,7 @@
 from .layout_details.hopper_scale import HopperMXScaleLayout
 from .layout_details.hopper_value import HopperMXValueLayout
 from .layout_details.strided import StridedLayout
+from ..target_info import cuda_capability_geq
 
 __all__ = [
     "Layout",
@@ -11,3 +12,21 @@
     "HopperMXValueLayout",
     "StridedLayout",
 ]
+
+
+def make_default_matmul_mxfp4_w_layout(mx_axis: int):
+    if cuda_capability_geq(10):
+        return StridedLayout, dict()
+    elif cuda_capability_geq(9):
+        return HopperMXValueLayout, {"mx_axis": mx_axis}
+    else:
+        return StridedLayout, dict()
+
+
+def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8):
+    if cuda_capability_geq(10):
+        return BlackwellMXScaleLayout, dict()
+    elif cuda_capability_geq(9):
+        return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps}
+    else:
+        return StridedLayout, dict()
