[rfc] mxfp act in and out matmul (not for persistent kernel yet) (#7598)
Not-yet-optimized matmul with mxfp8 activation input and output (e.g., mxfp8 activation input/output is not yet supported with the persistent kernel). A few notables:

* Not working with Hopper swizzling (but working with Blackwell weight value/scale swizzling).
* Reuses the `actual_scale` that flexpoint scaling already uses to pass the mx scale in a couple of places, so as not to grow the number of arguments.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because `added more test cases covering the changes in python/triton_kernels/tests/test_matmul.py`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent a2d179d commit a837a04
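For context, the MX format referenced throughout groups each run of 32 consecutive values along the reduction axis under one shared power-of-two scale stored as a biased uint8 exponent, with an fp8 e4m3 payload. A minimal pure-PyTorch sketch of that idea follows; it is illustrative only, and `toy_downcast_to_mxfp8`/`toy_upcast_from_mxfp8` are stand-ins, not the library's `downcast_to_mxfp`/`upcast_from_mxfp`:

```python
import torch

MXFP_BLOCK_SIZE = 32  # block size per the OCP MX spec (assumed here)

def toy_downcast_to_mxfp8(x: torch.Tensor):
    # Split the last axis into blocks of 32 and pick one power-of-two scale
    # per block so the block max fits in fp8 e4m3 range (max finite = 448).
    *lead, n = x.shape
    assert n % MXFP_BLOCK_SIZE == 0
    blocks = x.reshape(*lead, n // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp_min(2.0**-127)
    exp = torch.ceil(torch.log2(amax / 448.0)).clamp_min(-127)
    q = (blocks / torch.exp2(exp)).to(torch.float8_e4m3fn)
    e8m0 = (exp.squeeze(-1) + 127).to(torch.uint8)  # biased per-block exponent
    return q.reshape(x.shape), e8m0

def toy_upcast_from_mxfp8(q: torch.Tensor, e8m0: torch.Tensor):
    # Invert: decode the biased exponent and rescale each block.
    scale = torch.exp2(e8m0.to(torch.float32) - 127).unsqueeze(-1)
    blocks = q.to(torch.float32).reshape(*q.shape[:-1], -1, MXFP_BLOCK_SIZE)
    return (blocks * scale).reshape(q.shape)

x = torch.randn(4, 256)
q, s = toy_downcast_to_mxfp8(x)
x_rt = toy_upcast_from_mxfp8(q, s)
assert (x - x_rt).abs().max() < 0.1 * x.abs().max()  # coarse round-trip check
```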

File tree

9 files changed: +295 −90 lines

python/triton_kernels/tests/test_matmul.py

Lines changed: 76 additions & 21 deletions
@@ -1,4 +1,6 @@
-from dataclasses import dataclass, fields
+# isort: off
+# fmt: off
+from dataclasses import dataclass, fields, replace
 import pytest
 import torch
 from typing import Union
@@ -7,14 +9,14 @@
 from triton_kernels.routing import routing
 # matmul utilities
 import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
-from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs
+from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs, FnName, Epilogue
 from triton_kernels.matmul_ogs import matmul_ogs_set_idle_sms, matmul_ogs, matmul_ogs_torch
 from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig
 from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4
 from triton_kernels.tensor_details import layout
 # numerics utilities
 from triton_kernels.numerics import InFlexData, OutFlexData
-from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
+from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp, dequantize_mxfp8_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE
 # testing utilities
 from triton_kernels.testing import assert_close, compute_actual_scale
 # target-specific utilities
@@ -78,9 +80,8 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
 # ---------------


-def init_precision(out_dtype, weight_dtype, is_mixed_input, n_expts_tot=1, device="cuda"):
-    act_use_flexpoint = out_dtype.itemsize == 1
-    weight_use_flexpoint = weight_dtype.itemsize == 1 and not is_mixed_input
+def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, n_expts_tot=1, device="cuda"):
+    weight_use_flexpoint = weight_dtype.itemsize == 1 and not weight_mxfp
     # flexpoint
     make_tensor = lambda val0, val1: torch.tensor([val0, val1] * (n_expts_tot // 2) +
                                                   ([val0]
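Note on the `init_precision` change above: the activation flexpoint flag is now passed in by the caller instead of being derived from the element size, presumably so the mxfp8 activation path (1-byte payload, but block-scaled rather than flexpoint-scaled) can opt out. The old derivation was just a 1-byte check:

```python
import torch

# Pre-refactor rule: activations used flexpoint iff the output element
# type was 1 byte wide. mxfp8 activations break this heuristic: the
# payload is 1 byte, yet scaling comes from mx block scales instead.
assert torch.float8_e5m2.itemsize == 1  # 1-byte fp8 -> was flexpoint
assert torch.bfloat16.itemsize == 2     # 2-byte bf16 -> was not
```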
@@ -106,13 +107,14 @@ def apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_config):

     def apply(x, scale):
         if scale is None:
-            return x.clone().detach().requires_grad_(True)
+            x = x.clone()
         elif scale.numel() == 1:
-            return (x.float() * scale).detach().requires_grad_(True)
+            x = x.float() * scale
         else:
             assert x.ndim == 3
             assert scale.numel() == x.shape[0]
-            return (x.float() * scale[:, None, None]).detach().requires_grad_(True)
+            x = x.float() * scale[:, None, None]
+        return x.detach().requires_grad_()

     return (
         apply(x_tri, flex_ctx.lhs_data.scale),
@@ -215,6 +217,19 @@ class Case:
     Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
     Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
     Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=False),
+    Case(16, 256, 256, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
+    Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
+    Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
+    Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
+    Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+    Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
+    Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
+    Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
+    Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
+    Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4),
+    Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
+    Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4),
+    Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
     # AMD
     Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"),
     Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1),
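For orientation, the positional arguments of the `Case` records above read as follows; this is a hypothetical reduction inferred from the calls (the real dataclass in the test file carries additional fields such as `n_expt_shards`):

```python
from dataclasses import dataclass

@dataclass
class Case:  # hypothetical reduction; field names inferred from the calls
    m: int
    n: int
    k: int
    mode: str              # "ragged" or "batched"
    act_dtype_str: str     # "mx" prefix => microscaled activation
    weight_dtype_str: str  # "mx" prefix => microscaled weight
    n_expts_tot: int = 1
    n_expts_act: int = 1
    split_k: int = 1
    hbm_swizzling: bool = False

c = Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4)
assert c.act_dtype_str.startswith("mx") and c.weight_dtype_str.startswith("mx")
```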
@@ -247,8 +262,12 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         pytest.skip("Float8 not tested on A100")
     if "float16" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] >= 10:
         pytest.skip("float16 x mx not supported with cuda capability >= 10")
-    if "float8" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 10:
-        pytest.skip("float8 x mx not supported with cuda capability < 10")
+    if weight_dtype_str.startswith("mx"):
+        if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10:
+            pytest.skip("float8 x mx not supported with cuda capability < 10")
+        if act_dtype_str == "mxfloat8_e4m3fn":
+            if is_persistent:
+                pytest.skip("mx x mx not supported with persistent kernel")
     if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9:
         pytest.skip("Not enough memory on A100")

@@ -257,6 +276,8 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
             pytest.skip("float8 x mx only supported on CDNA4")
         if "float8" in act_dtype_str and "mxfloat8" in weight_dtype_str:
             pytest.skip("NYI: float8 x mxfloat8 not tested on AMD GPU")
+        if act_dtype_str.startswith("mx") and weight_dtype_str.startswith("mx"):
+            pytest.skip("NYI: mx x mx not tested on AMD GPU")
         if is_persistent:
             pytest.skip("NYI: Persistent kernel not supported on AMD GPU")
         if split_k > 1:
@@ -301,24 +322,30 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     }
     opt_flags.update_opt_flags_constraints(constraints)

-    is_mixed_input = act_dtype_str != weight_dtype_str
-    if weight_dtype_str.startswith("mx"):
+    weight_mxfp = weight_dtype_str.startswith("mx")
+    if weight_mxfp:
         weight_dtype_str = weight_dtype_str[2:]
+    act_mxfp8 = act_dtype_str.startswith("mx")
+    act_is_float8 = act_dtype_str.startswith("float8")
+    if act_mxfp8:
+        act_dtype_str = act_dtype_str[2:]
+        dequantize_mxfp8_spec = FnSpecs(
+            FnName.DEQUANTIZE_MXFP8.name, dequantize_mxfp8_fn, (), ()
+        )

     test_bwd = False
     weight_dtype = dtype_str_to_torch(weight_dtype_str)
     act_dtype = dtype_str_to_torch(act_dtype_str)
-    act_is_float8 = act_dtype.itemsize == 1
-    precision_opt = init_precision(act_dtype, weight_dtype, is_mixed_input, n_expts_tot // n_expt_shards, device=device)
+    precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, weight_mxfp, n_expts_tot // n_expt_shards, device=device)
     # precision_opt.x_pad_trans_requires_flexpoint = False
     if mode == "ragged":
         m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
                                                    device=device)
     else:
         rdata = gindx = sindx = None
     x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act,
-                                                                 n_expt_shards, mode, act_dtype, #
-                                                                 torch.bfloat16 if is_mixed_input else weight_dtype,
+                                                                 n_expt_shards, mode, torch.bfloat16 if act_mxfp8 else act_dtype, #
+                                                                 torch.bfloat16 if weight_mxfp else weight_dtype,
                                                                  has_y_gammas, requires_grad=test_bwd, device=device)
     x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)

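The dtype-string convention the hunk above leans on: an `mx` prefix marks a microscaled tensor, and stripping the prefix leaves the storage element type that `dtype_str_to_torch` can resolve. A standalone sketch of that convention (the helper name here is illustrative):

```python
def parse_dtype_str(dtype_str: str):
    """Split an 'mx'-prefixed dtype string into (element dtype, is_mx)."""
    is_mx = dtype_str.startswith("mx")
    if is_mx:
        dtype_str = dtype_str[2:]  # "mxfloat8_e4m3fn" -> "float8_e4m3fn"
    return dtype_str, is_mx

assert parse_dtype_str("mxfloat8_e4m3fn") == ("float8_e4m3fn", True)
assert parse_dtype_str("float8_e5m2") == ("float8_e5m2", False)
```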
@@ -327,7 +354,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
         w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)

-    if is_mixed_input:
+    if weight_mxfp:
         mx_axis = w_tri.ndim - 2
         # compute layouts
         w_layout, w_layout_opts = layout.StridedLayout, dict()
@@ -346,6 +373,25 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
         w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
         precision_opt.weight_scale = w_scale_tri
+    epilogue = None
+    if act_mxfp8:
+        x_tri, x_mx_scales_tri = downcast_to_mxfp(x_tri, act_dtype, axis=-1)
+        x_ref = upcast_from_mxfp(x_tri, x_mx_scales_tri, torch.bfloat16, axis=-1)
+        is_input_batched = x_tri.ndim == 3
+        y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
+        n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
+        y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
+        if sindx is None or mode == "batched":
+            if not is_input_batched:
+                y_shape = (y_shape[1], y_shape[2])
+        else:
+            y_shape = (n_rows // rdata.n_expts_act, y_shape[-1])
+        y_scale_shape = y_shape[:-1] + (triton.cdiv(y_shape[-1], MXFP_BLOCK_SIZE),)
+        y_scale = torch.empty(y_scale_shape, dtype=torch.uint8, device=x_tri.device)
+        precision_opt = replace(precision_opt, act_scale=x_mx_scales_tri, out_scale=y_scale)
+        epilogue = Epilogue(dequantize_mxfp8_spec, tuple(), tuple(), effective_itemsize=6.0)
+    else:
+        y_scale = None

     if test_launch_metadata:

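The `y_scale` buffer allocated in the hunk above holds one uint8 scale per `MXFP_BLOCK_SIZE`-wide block along the output's last axis. A worked shape check, assuming a block size of 32 (the test imports the actual constant from `numerics_details.mxfp`):

```python
import math

MXFP_BLOCK_SIZE = 32  # assumed value

# For an unbatched output y of shape (m, n): keep the leading dims and
# divide the last one by the block size, rounding up (triton.cdiv above).
y_shape = (1000, 704)
y_scale_shape = y_shape[:-1] + (math.ceil(y_shape[-1] / MXFP_BLOCK_SIZE),)
assert y_scale_shape == (1000, 22)  # ceil(704 / 32) = 22 scales per row
```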
@@ -393,7 +439,7 @@ def _hook(launch_metadata):

     # triton
     try:
-        tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref)
+        tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref, epilogue=epilogue)
     except (opt_flags.InapplicableConstraint, NotImplementedError):
         pytest.skip("inapplicable opt_flags constraint")
     # If split_k > 1, then the intermediate tensor is fp32.
@@ -432,7 +478,16 @@ def round_x(x, idx):
         assert n_rows > 0
         ref_y = ref_y[:n_rows]
         tri_y = tri_y[:n_rows]
-    assert_close(scale(ref_y, flex.out_data.expected_scale), tri_y)
+    if act_mxfp8:
+        tri_y = upcast_from_mxfp(tri_y, precision_opt.out_scale, dtype=torch.bfloat16, axis=-1).to(ref_y.dtype)
+        ref_y_quant, ref_y_scale = downcast_to_mxfp_torch(ref_y, act_dtype, axis=-1)
+        ref_y = upcast_from_mxfp_torch(ref_y_quant, ref_y_scale, target_dtype=ref_y.dtype, axis=-1)
+        maxtol = 4e-1
+        rmstol = 4e-2
+    else:
+        maxtol = None
+        rmstol = None
+    assert_close(scale(ref_y, flex.out_data.expected_scale), tri_y, maxtol=maxtol, rmstol=rmstol)

     if act_is_float8:
         tri_y_scale = flex.out_data.actual_scale.clone()
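Why the reference is round-tripped before the comparison above: pushing `ref_y` through the same mxfp8 quantize/dequantize as the kernel output puts both tensors on the same quantization grid, so the loosened `maxtol`/`rmstol` only need to absorb accumulation-order differences. A toy illustration using plain tensor-wise fp8 casting in place of the mx helpers:

```python
import torch

ref = torch.randn(128, 64)
# Pretend the kernel emitted an fp8-quantized result:
tri = ref.to(torch.float8_e4m3fn).to(torch.float32)
# Against the raw reference, the gap is dominated by quantization error...
raw_gap = (tri - ref).abs().max()
# ...but round-tripping the reference removes that systematic component:
ref_rt = ref.to(torch.float8_e4m3fn).to(torch.float32)
assert (tri - ref_rt).abs().max() == 0 and raw_gap > 0
```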
@@ -495,7 +550,7 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter,
     else:
         rdata = gindx = sindx = None

-    precision_opt = init_precision(act_dtype, weight_dtype, False, n_expts_tot // n_expt_shards, device=device)
+    precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, n_expts_tot // n_expt_shards, device=device)
     x, w, bias, _, _ = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode,
                                          act_dtype, weight_dtype, False, requires_grad=False, device=device)

