
Commit de846c0

[KERNELS] added missing mxfp4 tests (#7591)
1 parent fde96e8 commit de846c0

2 files changed: +119 -0 lines changed
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
import pytest
import torch
from triton_kernels.tensor_details.layout import BlackwellMXScaleLayout

# ------------------------------------------------------------
# Torch tests
# ------------------------------------------------------------


@pytest.mark.parametrize(
    "shape",
    [
        (3, 4096, 1024),
        (10, 254, 60),
        (1, 320, 160),
        (2, 16, 512),
        (3, 2, 36),
    ],
)
def test_mxfp4_scale_roundtrip(shape):
    x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
    layout = BlackwellMXScaleLayout(x.shape)
    res = layout.unswizzle_data(layout.swizzle_data(x))
    assert (res == x).all()
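
The test above checks a single property: swizzling a uint8 scale tensor into the Blackwell MX layout and unswizzling it back must reproduce the input bit-for-bit. Below is a minimal standalone sketch of the same check, assuming a CUDA device and the triton_kernels package from this repository; check_roundtrip is a hypothetical helper name, not part of the commit.

import torch
from triton_kernels.tensor_details.layout import BlackwellMXScaleLayout


def check_roundtrip(layout, x):
    # hypothetical helper: swizzle then unswizzle must return the exact input bytes
    return bool((layout.unswizzle_data(layout.swizzle_data(x)) == x).all())


x = torch.randint(0, 256, (3, 4096, 1024), dtype=torch.uint8, device="cuda")
assert check_roundtrip(BlackwellMXScaleLayout(x.shape), x)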
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
import pytest
from triton_kernels.tensor import wrap_torch_tensor, convert_layout, FP4
from triton_kernels.tensor_details.layout import HopperMXScaleLayout, HopperMXValueLayout
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
import triton.language as tl
import triton
import torch

# ------------------------------------------------------------
# Torch tests
# ------------------------------------------------------------


@pytest.mark.parametrize("shape", [(16, 32), (16, 64), (32, 32), (32, 64), (64, 128), (128, 128)])
@pytest.mark.parametrize("trans", [False, True])
@pytest.mark.parametrize("mx_axis", [0, 1])
@pytest.mark.parametrize("mma_version", [2, 3])
def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version):
    x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
    if trans:
        x = x.mT
    if x.shape[1 - mx_axis] < 32:
        pytest.skip("Not enough elements along non-mx axis")
    layout = HopperMXValueLayout(x.shape, mx_axis, mma_version)
    res = layout.unswizzle_data(layout.swizzle_data(x))
    assert (res == x).all()


@pytest.mark.parametrize("mx_axis", [0, 1])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.parametrize("shape", [(256, 64), (256, 128), (256, 256)])
def test_mxfp4_scale_roundtrip(shape, mx_axis, num_warps):
    x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
    layout = HopperMXScaleLayout(x.shape, mx_axis=mx_axis, num_warps=num_warps)
    res = layout.unswizzle_data(layout.swizzle_data(x))
    assert (res[:shape[0], :shape[1]] == x).all()


# ------------------------------------------------------------
# Triton tests
# ------------------------------------------------------------

# ------------------ upcast mxfp4 to bf16 --------------------


@triton.jit
def _upcast_mxfp4_to_bf16(Y, X, XScale, x_stride_m, x_stride_n, x_scale_stride_m, x_scale_stride_n, y_stride_m,
                          y_stride_n, X_BLOCK_M: tl.constexpr, X_BLOCK_N: tl.constexpr, Y_BLOCK_M: tl.constexpr,
                          Y_BLOCK_N: tl.constexpr, SCALE_BLOCK_M: tl.constexpr, SCALE_BLOCK_N: tl.constexpr,
                          mx_axis: tl.constexpr):
    offs_m_val = tl.arange(0, X_BLOCK_M)
    offs_n_val = tl.arange(0, X_BLOCK_N)
    offs_m_scale = tl.arange(0, SCALE_BLOCK_M)
    offs_n_scale = tl.arange(0, SCALE_BLOCK_N)
    # load values
    offs_x = offs_m_val[:, None] * x_stride_m + offs_n_val[None, :] * x_stride_n
    x = tl.load(X + offs_x)
    # load scales
    offs_x_scale = offs_m_scale[:, None] * x_scale_stride_m + offs_n_scale[None, :] * x_scale_stride_n
    x_scale = tl.load(XScale + offs_x_scale)
    x_scale = unswizzle_mxfp4_scale_hopper(x_scale, mx_axis=mx_axis, num_warps=tl.extra.cuda.num_warps())
    y = mxfp4_to_bf16_triton(x, x_scale, mx_axis=mx_axis)
    # write back output
    offs_m_val = tl.arange(0, Y_BLOCK_M)
    offs_n_val = tl.arange(0, Y_BLOCK_N)
    offs_y = offs_m_val[:, None] * y_stride_m + offs_n_val[None, :] * y_stride_n
    tl.store(Y + offs_y, y)


def test_upcast_mxfp4_to_bf16():
    mx_axis = 0
    num_warps = 4
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    shape = (256, 128)
    x = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
    x_fp4_val, x_fp4_scale = downcast_to_mxfp(x, torch.uint8, axis=mx_axis)
    x_bf16 = upcast_from_mxfp(x_fp4_val, x_fp4_scale, x.dtype, axis=mx_axis)
    x_fp4_val = wrap_torch_tensor(x_fp4_val, dtype=FP4)
    x_fp4_scale = wrap_torch_tensor(x_fp4_scale)
    x_fp4_val = convert_layout(x_fp4_val, HopperMXValueLayout, mx_axis=mx_axis)
    x_fp4_scale = convert_layout(x_fp4_scale, HopperMXScaleLayout, mx_axis=mx_axis, num_warps=num_warps)
    y = torch.empty_like(x_bf16)
    _upcast_mxfp4_to_bf16[(1, )](
        y, x_fp4_val.storage.data, x_fp4_scale.storage.data,  #
        x_fp4_val.storage.data.stride(0), x_fp4_val.storage.data.stride(1),  #
        x_fp4_scale.storage.data.stride(0), x_fp4_scale.storage.data.stride(1),  #
        y.stride(0), y.stride(1),  #
        x_fp4_val.storage.data.shape[0], x_fp4_val.storage.data.shape[1],  #
        shape[0], shape[1],  #
        x_fp4_scale.storage.data.shape[0], x_fp4_scale.storage.data.shape[1],  #
        mx_axis=mx_axis, num_warps=num_warps)
    assert (y == x_bf16).all()
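
A minimal sketch for running both new test files directly with pytest; the file names below are assumptions (the commit page does not show the paths), and a CUDA-capable GPU is required.

import sys

import pytest

if __name__ == "__main__":
    # hypothetical file names; substitute the paths actually added by this commit
    sys.exit(pytest.main(["-v", "test_blackwell_scale_layout.py", "test_hopper_mxfp4_layout.py"]))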
