
Commit 08973b1

[BENCH] Implement value and scale swizzling for Hopper bf16xmxfp4 (#6833)
This improves utilisation from 28% to 37%. Edit: with triton-lang/triton#6812 we hit 49% utilisation. We don't try to optimise the upcasting sequence yet, but we could do that in a future PR if we are willing to add a global scale via a packing/unpacking sequence by Scott Gray. We also add tests for quantisation in general.
1 parent 971a52a commit 08973b1

File tree: 9 files changed, +923 −287 lines

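For context, the utilisation numbers quoted in the commit message are roofline ratios computed by the benchmark, and the bench_mlp.py diff below changes the compute term to use the peak TFLOPS matching the activation bitwidth rather than max(TFLOPS8, TFLOPS16). A minimal sketch of that calculation (not part of the commit), assuming a SPECS-style dictionary with the MAX_TFLOPS8, MAX_TFLOPS16 and MAX_TBPS entries that bench_mlp.py reads:

def roofline_util(time, flops, bytes_moved, bitwidth, specs):
    # Mirrors PerfData.util after this commit: pick the compute roofline
    # from the activation bitwidth instead of max(TFLOPS8, TFLOPS16).
    assert bitwidth in (8, 16)
    peak_flops = specs["MAX_TFLOPS8"] if bitwidth == 8 else specs["MAX_TFLOPS16"]
    min_t_flop = flops / peak_flops * 1e-3  # ns -> us, as in the original comment
    min_t_bw = bytes_moved / specs["MAX_TBPS"] * 1e-3
    return max(min_t_flop, min_t_bw) / time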

python/triton_kernels/bench/bench_mlp.py

Lines changed: 28 additions & 7 deletions
@@ -1,10 +1,12 @@
 from pathlib import Path
+from copy import deepcopy
 import matplotlib.pyplot as plt
 import json
 import triton.profiler as proton
 import torch
+import triton_kernels
 import triton_kernels.swiglu
-from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
+from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, SwizzlingType
 from triton_kernels.matmul_ogs import MicroscalingCtx, matmul_ogs, PrecisionConfig, FlexCtx
 from triton_kernels.numerics import InFlexData
 from triton_kernels.routing import routing
@@ -61,10 +63,14 @@ def quantize(w, dtype, dev, **opt):
     else:
         assert dtype == "mx4", f"{dtype=}"
         swizzle_mx_scale = opt["swizzle_mx_scale"]
+        swizzle_mx_value = opt["swizzle_mx_value"]
         swizzle_axis = 2 if swizzle_mx_scale else None
         w = w.to(torch.bfloat16)
-        w, mx_scales, weight_scale_shape = downcast_to_mxfp(w, torch.uint8, axis=1, swizzle_axis=swizzle_axis)
-        return w, InFlexData(), MicroscalingCtx(weight_scale=mx_scales, swizzle_mx=swizzle_mx_scale,
+        w, mx_scales, weight_scale_shape = downcast_to_mxfp(w, torch.uint8, axis=1, swizzle_axis=swizzle_axis,
+                                                            swizzle_scale=swizzle_mx_scale,
+                                                            swizzle_value=swizzle_mx_value)
+        return w, InFlexData(), MicroscalingCtx(weight_scale=mx_scales, swizzle_scale=swizzle_mx_scale,
+                                                swizzle_value=swizzle_mx_value,
                                                 actual_weight_scale_shape=weight_scale_shape)
 
 
@@ -73,6 +79,7 @@ class PerfData:
     time: float
     flops: float
     bytes: float
+    bitwidth: int
 
     @property
     def tflops(self):
@@ -92,8 +99,9 @@ def opint(self):
     def util(self) -> float:
         if SPECS is None:
             return 0.0
+        assert self.bitwidth in (8, 16)
 
-        peak_flops = max(SPECS["MAX_TFLOPS8"], SPECS.get("MAX_TFLOPS16", 0))
+        peak_flops = SPECS["MAX_TFLOPS8"] if self.bitwidth == 8 else SPECS["MAX_TFLOPS16"]
         min_t_flop = self.flops / peak_flops * 1e-3  # ns → µs
         min_t_bw = self.bytes / SPECS["MAX_TBPS"] * 1e-3
         return max(min_t_flop, min_t_bw) / self.time
@@ -116,8 +124,21 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
 
     # -- numerics --
     optg = dict()
-    opt1 = {"swizzle_mx_scale": True} if w_dtype == "mx4" else dict()
-    opt2 = {"swizzle_mx_scale": True} if w_dtype == "mx4" else dict()
+    opt1 = dict()
+    opt2 = dict()
+    if w_dtype == "mx4" and not is_hip():
+        if torch.cuda.get_device_capability()[0] < 9:
+            # NYI for Ampere
+            swizzle_mx_value = None
+            swizzle_mx_scale = None
+        elif torch.cuda.get_device_capability()[0] < 10:
+            swizzle_mx_value = SwizzlingType.HOPPER
+            swizzle_mx_scale = SwizzlingType.HOPPER
+        else:
+            swizzle_mx_value = None
+            swizzle_mx_scale = SwizzlingType.BLACKWELL
+        opt1 = {"swizzle_mx_value": swizzle_mx_value, "swizzle_mx_scale": swizzle_mx_scale}
+        opt2 = deepcopy(opt1)
     wg, wg_flex, wg_mx = quantize(wg, "bf16", dev, **optg)
     w1, w1_flex, w1_mx = quantize(w1, w_dtype, dev, **opt1)
     w2, w2_flex, w2_mx = quantize(w2, w_dtype, dev, **opt2)
@@ -165,7 +186,7 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
     # TODO: proton should really be recording that in the json instead of
     # relying on the user to aggregate
     time = sum(x["metrics"].get("time (ns)", 0) for x in data[0]["children"])
-    return PerfData(time, flops, bytes)
+    return PerfData(time, flops, bytes, x_dtype.itemsize * 8)
 
 
 def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP=1, EP=1, name="",
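Put together, the bench_mlp.py changes select a swizzling scheme per GPU generation and thread it through quantisation. A minimal standalone sketch (not the committed code verbatim), assuming the downcast_to_mxfp and MicroscalingCtx signatures shown in the diff above:

import torch
from triton_kernels.matmul_ogs import MicroscalingCtx
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import SwizzlingType, downcast_to_mxfp

def pick_swizzling():
    # Mirrors the capability checks added to bench_mlp (CUDA only; HIP gets no swizzling).
    cap = torch.cuda.get_device_capability()[0]
    if cap < 9:   # Ampere: swizzling not implemented
        return None, None
    if cap < 10:  # Hopper: swizzle both values and scales
        return SwizzlingType.HOPPER, SwizzlingType.HOPPER
    return None, SwizzlingType.BLACKWELL  # Blackwell: swizzle scales only

def quantize_mx4(w, swizzle_mx_value, swizzle_mx_scale):
    # Same shape as the updated quantize() branch for dtype == "mx4".
    swizzle_axis = 2 if swizzle_mx_scale else None
    w = w.to(torch.bfloat16)
    w, mx_scales, weight_scale_shape = downcast_to_mxfp(
        w, torch.uint8, axis=1, swizzle_axis=swizzle_axis,
        swizzle_value=swizzle_mx_value, swizzle_scale=swizzle_mx_scale)
    mx_ctx = MicroscalingCtx(weight_scale=mx_scales, swizzle_value=swizzle_mx_value,
                             swizzle_scale=swizzle_mx_scale,
                             actual_weight_scale_shape=weight_scale_shape)
    return w, InFlexData(), mx_ctx

Note the asymmetry encoded in the diff: on Hopper both the fp4 values and the scales are swizzled, while on Blackwell only the scales are.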

python/triton_kernels/tests/test_matmul.py

Lines changed: 53 additions & 25 deletions
@@ -2,7 +2,6 @@
 import pytest
 import torch
 from typing import Union
-# benchmarking utilities
 # routing utilities
 from triton_kernels.routing import routing
 # matmul utilities
@@ -12,7 +11,7 @@
 from triton_kernels.matmul_ogs import matmul_ogs, matmul_ogs_torch
 # numerics utilities
 from triton_kernels.numerics import InFlexData, OutFlexData
-from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
+from triton_kernels.numerics_details.mxfp import SwizzlingType, downcast_to_mxfp, upcast_from_mxfp
 # testing utilities
 from triton_kernels.testing import assert_close, compute_actual_scale
 # target-specific utilities
@@ -139,7 +138,7 @@ class Case:
     n_expts_act: int = 1
     n_expt_shards: int = 1
     split_k: int = 1
-    swizzle_mx_scale: bool = False
+    hbm_swizzling: bool = False
     epilogue_subtile: Union[bool, None] = None
 
 
@@ -174,25 +173,28 @@ class Case:
     Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9),
     # mx types:
     Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),
+    Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
     Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
+    Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
     Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9),
+    Case(1000, 512, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
     Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4),
-    Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, swizzle_mx_scale=True),
+    Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
     Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4),
     Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
-    Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, swizzle_mx_scale=True),
-    Case(1000, 704, 800, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, swizzle_mx_scale=True),
-    Case(1000, 704, 800, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, swizzle_mx_scale=False),
-    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, swizzle_mx_scale=False),
-    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, swizzle_mx_scale=True),
-    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, swizzle_mx_scale=False),
-    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, swizzle_mx_scale=True),
-    Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, swizzle_mx_scale=False),
-    Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, swizzle_mx_scale=True),
-    Case(300, 400, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, swizzle_mx_scale=False),
-    Case(300, 400, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, swizzle_mx_scale=True),
-    Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, swizzle_mx_scale=False),
-    Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, swizzle_mx_scale=True),
+    Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
+    Case(1000, 704, 800, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
+    Case(1000, 704, 800, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1),
+    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9),
+    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2),
+    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
+    Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4),
+    Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
+    Case(300, 400, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4),
+    Case(300, 400, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
+    Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4),
+    Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
     # AMD
     Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"),
     Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1),
@@ -214,8 +216,8 @@ class Case:
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot,
-            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, swizzle_mx_scale,
-            epilogue_subtile, device):
+            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
+            device):
     # TODO: remove when Triton FP8 supports proper RTNE
     if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9:
         pytest.skip("Float8 not tested on A100")
@@ -229,11 +231,22 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         pytest.skip("float8 x mx not supported with cuda capability < 10")
     if fused_scatter and split_k > 1:
         pytest.skip("fused scatter scratchpad not supported with split_k")
+    if hbm_swizzling:
+        if is_hip():
+            pytest.skip("NYI. HBM swizzling just implemented for CUDA.")
+        if torch.cuda.get_device_capability()[0] < 9:
+            pytest.skip("NYI. Ampere swizzling.")
+        if torch.cuda.get_device_capability()[0] < 10:
+            if "mxfloat4" not in weight_dtype_str:
+                pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
+            if k % 64 != 0 or n % 64 != 0:
+                # Automatic padding not implemented for Hopper swizzle
+                pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
 
     torch.manual_seed(0)
 
     block_k = None
-    if is_persistent and weight_dtype_str.startswith("mx") and not torch.cuda.get_device_capability()[0] >= 10:
+    if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
         # Override block_k for testing correctness. The default is temporarily 128 for
         # performance reasons which doesn't work with persistent matmul.
         # TODO: revisit when Triton is better for H100 + MXFP4
@@ -273,12 +286,27 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)
 
     if is_mixed_input:
-        swizzle_axis = 2 if swizzle_mx_scale else None
+        if hbm_swizzling:
+            swizzle_axis = 2
+            if torch.cuda.get_device_capability()[0] < 10:
+                swizzle_value = SwizzlingType.HOPPER
+                swizzle_scale = SwizzlingType.HOPPER
+            else:
+                swizzle_value = None
+                swizzle_scale = SwizzlingType.BLACKWELL
+        else:
+            swizzle_axis = None
+            swizzle_value = None
+            swizzle_scale = None
         w_tri, mx_scales_tri, weight_scale_shape = downcast_to_mxfp(w_tri, weight_dtype, axis=1,
-                                                                    swizzle_axis=swizzle_axis)
-        w_ref = upcast_from_mxfp(w_tri, mx_scales_tri, torch.bfloat16, axis=1, swizzle_axis=swizzle_axis)
-
-        precision_opt.mx_ctx = MicroscalingCtx(weight_scale=mx_scales_tri, swizzle_mx=swizzle_mx_scale,
+                                                                    swizzle_axis=swizzle_axis,
+                                                                    swizzle_value=swizzle_value,
+                                                                    swizzle_scale=swizzle_scale)
+        w_ref = upcast_from_mxfp(w_tri, mx_scales_tri, torch.bfloat16, axis=1, swizzle_axis=swizzle_axis,
+                                 swizzle_value=swizzle_value, swizzle_scale=swizzle_scale)
+
+        precision_opt.mx_ctx = MicroscalingCtx(weight_scale=mx_scales_tri, swizzle_value=swizzle_value,
+                                               swizzle_scale=swizzle_scale,
                                                actual_weight_scale_shape=weight_scale_shape)
 
     if is_persistent and not can_use_persistent_tma(x_tri, w_tri, gindx, precision_opt):
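The updated test validates swizzling through a downcast/upcast round trip. The following sketch illustrates the same pattern outside the test harness; it assumes a Hopper or newer GPU, the signatures shown in the diff above, a hypothetical weight tensor whose last two dimensions are multiples of the 64x64 Hopper tile, and uint8 storage for the packed fp4 values as in bench_mlp.py:

import torch
from triton_kernels.numerics_details.mxfp import SwizzlingType, downcast_to_mxfp, upcast_from_mxfp

# Hypothetical (n_experts, K, N) weight; K and N multiples of 64 so Hopper swizzling applies.
w = torch.randn(8, 256, 256, dtype=torch.bfloat16, device="cuda")
cap = torch.cuda.get_device_capability()[0]
if cap < 10:  # Hopper: swizzle values and scales
    swizzle_value, swizzle_scale = SwizzlingType.HOPPER, SwizzlingType.HOPPER
else:         # Blackwell: swizzle scales only
    swizzle_value, swizzle_scale = None, SwizzlingType.BLACKWELL

w_q, scales, scale_shape = downcast_to_mxfp(w, torch.uint8, axis=1, swizzle_axis=2,
                                            swizzle_value=swizzle_value, swizzle_scale=swizzle_scale)
# Upcasting with the same swizzle arguments undoes the layout transform; mxfp4 is
# lossy, so any value comparison is approximate (the test uses its own tolerances).
w_back = upcast_from_mxfp(w_q, scales, torch.bfloat16, axis=1, swizzle_axis=2,
                          swizzle_value=swizzle_value, swizzle_scale=swizzle_scale)
assert w_back.shape == w.shape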
