
Commit fac1734

[mxfp] support quant/dequant from/to fp32 (triton-lang#7672)
It's slightly inconvenient to go through fp16/bf16 when we want to (de)quantize mxfp from/to fp32.

# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests -> Added more test cases in `pytest -xs python/triton_kernels/tests/test_mxfp.py`
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 9eee56d commit fac1734
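
Illustrative usage (a minimal sketch, not part of the commit): the round trip below assumes the `downcast_to_mxfp`/`upcast_from_mxfp` helpers exercised by the tests take a destination dtype and an `axis`, and that `downcast_to_mxfp` returns the quantized tensor together with its E8M0 scales; treat the exact signatures as an assumption.

```python
# Hedged sketch of what this change enables: (de)quantizing MXFP directly
# from/to float32 instead of detouring through fp16/bf16.
# The signatures below are assumed from the tests, not authoritative.
import torch
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp

x = torch.randn(16, 256, device="cuda", dtype=torch.float32)

# Quantize the float32 input to MXFP4 (two e2m1 values packed per uint8),
# with one E8M0 scale per 32-value block along the chosen axis.
xq, scale = downcast_to_mxfp(x, torch.uint8, axis=1)

# Dequantize straight back to float32 (previously limited to fp16/bf16).
x_dq = upcast_from_mxfp(xq, scale, torch.float32, axis=1)
assert x_dq.dtype == torch.float32
```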

4 files changed, +19 -17 lines changed


python/triton_kernels/tests/test_mxfp.py

Lines changed: 3 additions & 4 deletions
@@ -1,6 +1,5 @@
 import pytest
 import torch
-
 from triton_kernels.numerics_details.mxfp import (
     DequantScaleRoundingMode,
     downcast_to_mxfp,
@@ -16,7 +15,7 @@ def dtype_str_to_torch(dtype_str: str) -> torch.dtype:
     return torch.uint8 if dtype_str == "float4_e2m1" else getattr(torch, dtype_str)


-@pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16"])
+@pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"])
 def test_mxfp4_rounding_cases(dst_dtype):
     dst_dtype = dtype_str_to_torch(dst_dtype)
     x = torch.tensor([6, 0, 0.24, 0.25, 0.75, 0.99, 1.2, 1.3]).cuda().bfloat16().view(1, -1, 1)
@@ -33,7 +32,7 @@ def test_mxfp4_rounding_cases(dst_dtype):


 @pytest.mark.parametrize("src_dtype", ["float4_e2m1", "float8_e5m2", "float8_e4m3fn"])
-@pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16"])
+@pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"])
 def test_mxfp_quant_dequant(src_dtype, dst_dtype):
     if "float8" in src_dtype and torch.cuda.get_device_capability()[0] < 9:
         pytest.skip("Float8 not tested on A100")
@@ -79,7 +78,7 @@ def test_mxfp_quant_dequant(src_dtype, dst_dtype):
     ],
 )
 # fmt: on
-@pytest.mark.parametrize("dequant_dtype", ["float16", "bfloat16"])
+@pytest.mark.parametrize("dequant_dtype", ["float16", "bfloat16", "float32"])
 def test_mxfp_casting(
     shape: tuple[int, ...],
     axis: int,

python/triton_kernels/triton_kernels/numerics_details/mxfp.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dty
     assert tensor.dtype in {torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn}, \
         f"Invalid tensor dtype {tensor.dtype=}"
     assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}"
-    assert dtype in (torch.float16, torch.bfloat16), f"Invalid output dtype {dtype=}"
+    assert dtype in (torch.float16, torch.bfloat16, torch.float32), f"Invalid output dtype {dtype=}"
     # upcast
     logical_quant_dim = tensor.shape[axis] * (2 if tensor.dtype == torch.uint8 else 1)
     tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous()

python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.cons

     src_dtype: tl.constexpr = src_ptr.dtype.element_ty
     tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8")
-    tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16), f"{src_dtype=} must be bfloat16 or float16")
+    tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16) or (src_dtype == tl.float32), f"{src_dtype=} must be bfloat16 or float16 or float32")
     is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8

     outer_block = tl.program_id(0).to(tl.int64)
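
For orientation, here is a rough host-side illustration (a sketch only, not the kernel's exact algorithm or its `DequantScaleRoundingMode` handling) of why the downcast path is largely dtype-agnostic: the shared E8M0 scale is a pure power of two derived from the block maximum, so accepting a float32 source mainly amounts to relaxing the assert above.

```python
# Rough reference for the shared-scale idea behind MXFP4 downcasting.
# This is a simplification with a floor-style scale choice, not the kernel code.
import torch

def e8m0_block_scale(x: torch.Tensor) -> torch.Tensor:
    # x: (..., 32) float32 blocks; e2m1 magnitudes top out at 6.0 = 1.5 * 2**2.
    amax = x.abs().amax(dim=-1, keepdim=True).clamp_min(2**-126)
    exp = torch.floor(torch.log2(amax)) - 2   # shared exponent (one rounding choice)
    return torch.exp2(exp)                    # power-of-two scale, storable as E8M0

x = torch.randn(4, 32, dtype=torch.float32)
scaled = x / e8m0_block_scale(x)              # roughly in e2m1 range; values above
                                              # 6.0 would saturate on quantization
```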

python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py

Lines changed: 14 additions & 11 deletions
@@ -1,5 +1,6 @@
 import triton
 import triton.language as tl
+
 from ._downcast_to_mxfp import MXFP_BLOCK_SIZE


@@ -14,7 +15,7 @@ def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_
     # uint8 signifies two fp4 e2m1 values packed into a single byte
     mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
     dst_dtype: tl.constexpr = out_ptr.dtype.element_ty
-    tl.static_assert(dst_dtype == tl.float16 or dst_dtype == tl.bfloat16)
+    tl.static_assert(dst_dtype == tl.float16 or dst_dtype == tl.bfloat16 or dst_dtype == tl.float32)
     tl.static_assert(
         mx_tensor_dtype == tl.uint8
         or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype),
@@ -69,32 +70,33 @@ def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_
     if dst_dtype == tl.bfloat16:
         dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True)
     else:
-        tl.static_assert(dst_dtype == tl.float16)
         dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
-        dst_scale = dst_scale.to(tl.float16)
+        if dst_dtype == tl.float16:
+            dst_scale = dst_scale.to(tl.float16)

     # Now upcast the tensor.
+    intermediate_dtype: tl.constexpr = tl.bfloat16 if dst_dtype == tl.float32 else dst_dtype
     if is_fp8:
-        dst_tensor = tensor.to(dst_dtype)
+        dst_tensor = tensor.to(intermediate_dtype)
         if tensor.dtype == tl.float8e5:
             from_e_bits: tl.constexpr = 5
             from_m_bits: tl.constexpr = 2
-            to_e_bits: tl.constexpr = 8 if dst_dtype == tl.bfloat16 else 5
-            to_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10
+            to_e_bits: tl.constexpr = 8 if intermediate_dtype == tl.bfloat16 else 5
+            to_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10

             # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
             non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits
             non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits
             dst_tensor = tl.where(
                 (tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src,
-                (dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(dst_dtype, bitcast=True),
+                (dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(intermediate_dtype, bitcast=True),
                 dst_tensor,
             )
     else:
         assert is_fp4
-        dst_bias: tl.constexpr = 127 if dst_dtype == tl.bfloat16 else 15
-        dst_0p5: tl.constexpr = 16128 if dst_dtype == tl.bfloat16 else 0x3800
-        dst_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10
+        dst_bias: tl.constexpr = 127 if intermediate_dtype == tl.bfloat16 else 15
+        dst_0p5: tl.constexpr = 16128 if intermediate_dtype == tl.bfloat16 else 0x3800
+        dst_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10
         # e2m1
         em0 = tensor & 0x07
         em1 = tensor & 0x70
@@ -108,7 +110,8 @@ def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_
         x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
         x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
         # 3) x is zero, do nothing
-        dst_tensor = tl.interleave(x0, x1).to(dst_dtype, bitcast=True)
+        dst_tensor = tl.interleave(x0, x1).to(intermediate_dtype, bitcast=True)
+    dst_tensor = dst_tensor.to(dst_dtype)

     # Reshape for proper broadcasting: the scale was stored with a 32‐sized “inner” grouping.
     dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
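
Two properties this reworked upcast leans on can be checked on the host with plain PyTorch; the snippet below is an assumed-equivalent illustration, not the kernel code: shifting the E8M0 scale byte into a float32 exponent field reproduces the power-of-two scale exactly, and every fp4/fp8 element value is exactly representable in the bfloat16 intermediate, so widening to float32 afterwards is lossless.

```python
# Hedged host-side check of the kernel's two tricks (assumed equivalent).
import torch

# 1) E8M0 scale decode: the byte is an exponent, so placing it in the float32
#    exponent field (<< 23) yields 2**(e - 127) exactly.
e = torch.tensor([126, 127, 128], dtype=torch.int32)
scale_f32 = (e << 23).view(torch.float32)     # tensor([0.5, 1.0, 2.0])

# 2) fp4 e2m1 (and fp8) element values fit exactly in bfloat16, so upcasting to
#    the bf16 intermediate and then casting to float32 loses no information.
e2m1_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
assert torch.equal(e2m1_values.to(torch.bfloat16).to(torch.float32), e2m1_values)
```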
