Skip to content

Commit c0175fa

Browse files
authored
[KERNELS] refactor tma/mxfp/matmul (#7405)
1 parent 91d58f5 commit c0175fa

31 files changed

+1246
-1718
lines changed

python/triton_kernels/bench/bench_mlp.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import torch
77
import triton_kernels
88
import triton_kernels.swiglu
9-
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, SwizzlingType
9+
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
10+
from triton_kernels.tensor import SwizzlingType, swizzle
1011
from triton_kernels.matmul_ogs import MicroscalingCtx, matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
1112
from triton_kernels.numerics import InFlexData
1213
from triton_kernels.routing import routing
@@ -35,14 +36,12 @@ def quantize(w, dtype, dev, **opt):
3536
assert dtype == "mx4", f"{dtype=}"
3637
swizzle_mx_scale = opt.get("swizzle_mx_scale", None)
3738
swizzle_mx_value = opt.get("swizzle_mx_value", None)
38-
swizzle_axis = 2 if swizzle_mx_scale else None
3939
w = w.to(torch.bfloat16)
40-
w, mx_scales, weight_scale_shape = downcast_to_mxfp(w, torch.uint8, axis=1, swizzle_axis=swizzle_axis,
41-
swizzle_scale=swizzle_mx_scale,
42-
swizzle_value=swizzle_mx_value)
40+
w, mx_scales = downcast_to_mxfp(w, torch.uint8, axis=1)
41+
w = swizzle(w, swizzle_mx_value)
42+
mx_scales = swizzle(mx_scales, swizzle_mx_scale)
4343
return w, InFlexData(), MicroscalingCtx(weight_scale=mx_scales, swizzle_scale=swizzle_mx_scale,
44-
swizzle_value=swizzle_mx_value,
45-
actual_weight_scale_shape=weight_scale_shape)
44+
swizzle_value=swizzle_mx_value)
4645

4746

4847
@dataclass
@@ -111,11 +110,11 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
111110
swizzle_mx_value = None
112111
swizzle_mx_scale = None
113112
elif torch.cuda.get_device_capability()[0] < 10:
114-
swizzle_mx_value = SwizzlingType.HOPPER
115-
swizzle_mx_scale = SwizzlingType.HOPPER
113+
swizzle_mx_value = SwizzlingType.HOPPER_VALUE
114+
swizzle_mx_scale = SwizzlingType.HOPPER_SCALE
116115
else:
117116
swizzle_mx_value = None
118-
swizzle_mx_scale = SwizzlingType.BLACKWELL
117+
swizzle_mx_scale = SwizzlingType.BLACKWELL_SCALE
119118
opt1 = {"swizzle_mx_value": swizzle_mx_value, "swizzle_mx_scale": swizzle_mx_scale}
120119
opt2 = deepcopy(opt1)
121120
wg, wg_flex, wg_mx = quantize(wg, "bf16", dev, **optg)
@@ -216,7 +215,7 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
216215
batch_ranges_moe = [(128, 512, 32), (512, 32000, 128)]
217216
dense_dtypes = ["fp8", "fp8"]
218217
quantized_dtypes = ["fp8", "mx4"] if has_native_mx4 else ["bf16", "mx4"]
219-
roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
218+
# roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
220219
roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
221-
roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
222-
roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")
220+
# roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
221+
# roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")

python/triton_kernels/tests/test_matmul.py

Lines changed: 55 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
from triton_kernels.routing import routing
88
# matmul utilities
99
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
10-
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, MicroscalingCtx, FusedActivation, FnSpecs
11-
from triton_kernels.matmul_ogs import can_use_persistent_tma
10+
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs
1211
from triton_kernels.matmul_ogs import matmul_ogs_set_idle_sms, matmul_ogs, matmul_ogs_torch
1312
from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig
13+
from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4
14+
from triton_kernels.tensor_details import layout
1415
# numerics utilities
1516
from triton_kernels.numerics import InFlexData, OutFlexData
16-
from triton_kernels.numerics_details.mxfp import SwizzlingType, downcast_to_mxfp, upcast_from_mxfp
17+
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
1718
# testing utilities
1819
from triton_kernels.testing import assert_close, compute_actual_scale
1920
# target-specific utilities
@@ -53,20 +54,22 @@ def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_
5354
def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode, act_dtype, weight_dtype,
5455
has_y_gammas, requires_grad=True, device="cuda"):
5556
torch.manual_seed(0)
56-
assert mode in {'batched', 'ragged'}
57+
assert mode in {'batched', "plain", 'ragged'}
5758
in_m = m * (n_expts_act if gindx is None else 1)
5859
shape_x = (n_expts_tot, in_m, k) if mode == 'batched' else (in_m, k)
60+
shape_batch = tuple() if mode == "plain" else (n_expts_tot // n_expt_shards, )
5961
x = alloc_rand(shape_x, device=device, dtype=act_dtype, requires_grad=requires_grad)
60-
w = alloc_rand((n_expts_tot // n_expt_shards, k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad)
61-
bias = alloc_rand((n_expts_tot // n_expt_shards, n), device=device, dtype=torch.float32,
62-
requires_grad=requires_grad)
62+
w = alloc_rand(shape_batch + (k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad)
63+
bias = alloc_rand(shape_batch + (n, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
6364
gs0 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
6465
gs1 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
6566
gs0 = gs0.detach().requires_grad_(requires_grad)
6667
gs1 = gs1.detach().requires_grad_(requires_grad)
6768
if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
6869
gs0 = None
6970
gs1 = None
71+
if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
72+
w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
7073
return x, w, bias, gs0, gs1
7174

7275

@@ -75,7 +78,7 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
7578
# ---------------
7679

7780

78-
def init_precision(out_dtype, weight_dtype, is_mixed_input, n_expts_tot=1, mx_ctx=MicroscalingCtx(), device="cuda"):
81+
def init_precision(out_dtype, weight_dtype, is_mixed_input, n_expts_tot=1, device="cuda"):
7982
act_use_flexpoint = out_dtype.itemsize == 1
8083
weight_use_flexpoint = weight_dtype.itemsize == 1 and not is_mixed_input
8184
# flexpoint
@@ -95,7 +98,7 @@ def init_precision(out_dtype, weight_dtype, is_mixed_input, n_expts_tot=1, mx_ct
9598
out_data=out_flex_data(4.00, act_use_flexpoint),
9699
)
97100
return PrecisionConfig(flex_ctx=flex_ctx, acc_scale=2.0 if act_use_flexpoint or weight_use_flexpoint else 1.0,
98-
mx_ctx=mx_ctx, out_dtype=out_dtype)
101+
out_dtype=out_dtype)
99102

100103

101104
def apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_config):
@@ -183,8 +186,10 @@ class Case:
183186
Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2),
184187
Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9),
185188
# mx types:
186-
Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),
187-
Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
189+
Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1),
190+
Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
191+
Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1),
192+
Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
188193
Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
189194
Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
190195
Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9),
@@ -198,10 +203,10 @@ class Case:
198203
Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
199204
Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
200205
Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1),
201-
Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9),
202-
Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
203-
Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2),
204-
Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
206+
Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9),
207+
Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
208+
Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2),
209+
Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
205210
Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4),
206211
Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
207212
Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4),
@@ -317,38 +322,32 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
317322
has_y_gammas, requires_grad=test_bwd, device=device)
318323
x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)
319324

320-
if is_mixed_input:
321-
if hbm_swizzling:
322-
swizzle_axis = 2
323-
if torch.cuda.get_device_capability()[0] < 10:
324-
swizzle_value = SwizzlingType.HOPPER
325-
swizzle_scale = SwizzlingType.HOPPER
326-
else:
327-
swizzle_value = None
328-
swizzle_scale = SwizzlingType.BLACKWELL
329-
else:
330-
swizzle_axis = None
331-
swizzle_value = None
332-
swizzle_scale = None
333-
w_tri, mx_scales_tri, weight_scale_shape = downcast_to_mxfp(w_tri, weight_dtype, axis=1,
334-
swizzle_axis=swizzle_axis,
335-
swizzle_value=swizzle_value,
336-
swizzle_scale=swizzle_scale)
337-
w_ref = upcast_from_mxfp(w_tri, mx_scales_tri, torch.bfloat16, axis=1, swizzle_axis=swizzle_axis,
338-
swizzle_value=swizzle_value, swizzle_scale=swizzle_scale)
339-
340-
precision_opt.mx_ctx = MicroscalingCtx(weight_scale=mx_scales_tri, swizzle_value=swizzle_value,
341-
swizzle_scale=swizzle_scale,
342-
actual_weight_scale_shape=weight_scale_shape)
343-
344-
if is_persistent and not can_use_persistent_tma(x_tri, w_tri, gindx, precision_opt):
345-
pytest.skip("persistent TMAs not supported for this test")
346-
347325
if w_tri.shape[0] == 1:
348326
# Test the case when weight has dim 2, i.e., shape (K, N).
349327
w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
350328
w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)
351329

330+
if is_mixed_input:
331+
capability_major = torch.cuda.get_device_capability()[0]
332+
w_layout = layout.StridedLayout
333+
w_scale_layout = layout.StridedLayout
334+
if hbm_swizzling and "float4" in weight_dtype_str:
335+
# weight layout
336+
w_layouts = {9: layout.HopperMXValueLayout}
337+
w_layout = w_layouts.get(capability_major, layout.StridedLayout)
338+
# weight scale layout
339+
w_scales_layouts = {9: layout.HopperMXScaleLayout, 10: layout.BlackwellMXScaleLayout}
340+
w_scale_layout = w_scales_layouts.get(capability_major, layout.StridedLayout)
341+
w_tri, mx_scales_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=-2)
342+
w_ref = upcast_from_mxfp(w_tri, mx_scales_tri, torch.bfloat16, axis=-2)
343+
w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
344+
w_tri = convert_layout(wrap_torch_tensor(w_tri, w_tri_dtype), w_layout)
345+
mx_scales_tri = convert_layout(wrap_torch_tensor(mx_scales_tri), w_scale_layout)
346+
precision_opt.weight_scale = mx_scales_tri
347+
348+
# if not is_persistent and precision_opt.weight_scale is not None:
349+
# pytest.skip("non-persistent not supported with mxfp")
350+
352351
if test_launch_metadata:
353352

354353
def _clobber(t, used_mask):
@@ -394,7 +393,10 @@ def _hook(launch_metadata):
394393
flex = precision_opt.flex_ctx
395394

396395
# triton
397-
tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref)
396+
try:
397+
tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref)
398+
except (opt_flags.InapplicableConstraint, NotImplementedError):
399+
pytest.skip("inapplicable opt_flags constraint")
398400
# If split_k > 1, then the intermediate tensor is fp32.
399401
sep_gather = mode == "ragged" and do_gather and n_expts_act > 1 and split_k == 1
400402
sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1
@@ -498,16 +500,16 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter,
498500
x, w, bias, _, _ = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode,
499501
act_dtype, weight_dtype, False, requires_grad=False, device=device)
500502

501-
if is_persistent and not can_use_persistent_tma(x.view(1, x.shape[-2], x.shape[-1]),
502-
w.view(1, w.shape[-2], w.shape[-1]), gindx, precision_opt):
503-
pytest.skip("persistent TMAs not supported for this test")
504-
505503
if mode == "batched":
506504
rdata, gindx, sindx = None, None, None
507-
a = swiglu(matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt), swiglu_alpha,
508-
precision_config=SwiGLUPrecisionConfig(swiglu_limit))
509-
b = matmul_ogs(
510-
x, w, bias, rdata, gindx, sindx, precision_opt,
511-
fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (swiglu_alpha, swiglu_limit),
512-
2))
505+
506+
try:
507+
a = swiglu(matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt), swiglu_alpha,
508+
precision_config=SwiGLUPrecisionConfig(swiglu_limit))
509+
b = matmul_ogs(
510+
x, w, bias, rdata, gindx, sindx, precision_opt,
511+
fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
512+
(swiglu_alpha, swiglu_limit), 2))
513+
except opt_flags.InapplicableConstraint:
514+
pytest.skip("inapplicable constraint")
513515
assert_close(a, b)

0 commit comments

Comments
 (0)