|
| 1 | +import pytest |
| 2 | +from triton_kernels.tensor import wrap_torch_tensor, convert_layout, FP4 |
| 3 | +from triton_kernels.tensor_details.layout import HopperMXScaleLayout, HopperMXValueLayout |
| 4 | +from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp |
| 5 | +from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton |
| 6 | +from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper |
| 7 | +import triton.language as tl |
| 8 | +import triton |
| 9 | +import torch |
| 10 | + |
| 11 | +# ------------------------------------------------------------ |
| 12 | +# Torch tests |
| 13 | +# ------------------------------------------------------------ |
| 14 | + |
| 15 | + |
@pytest.mark.parametrize("shape", [(16, 32), (16, 64), (32, 32), (32, 64), (64, 128), (128, 128)])
@pytest.mark.parametrize("trans", [False, True])
@pytest.mark.parametrize("mx_axis", [0, 1])
@pytest.mark.parametrize("mma_version", [2, 3])
def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version):
    """Check that HopperMXValueLayout swizzle followed by unswizzle is the identity.

    A random uint8 CUDA tensor of the given shape (optionally viewed through
    `.mT`) is swizzled and unswizzled; the result must match the input
    element for element. Cases with fewer than 32 elements along the
    non-mx axis are skipped.
    """
    data = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
    data = data.mT if trans else data
    # The skip is decided on the (possibly transposed) shape, not on `shape`.
    non_mx_axis = 1 - mx_axis
    if data.shape[non_mx_axis] < 32:
        pytest.skip("Not enough elements along non-mx axis")
    value_layout = HopperMXValueLayout(data.shape, mx_axis, mma_version)
    swizzled = value_layout.swizzle_data(data)
    roundtripped = value_layout.unswizzle_data(swizzled)
    assert (roundtripped == data).all()
| 29 | + |
| 30 | + |
@pytest.mark.parametrize("mx_axis", [0, 1])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.parametrize("shape", [(256, 64), (256, 128), (256, 256)])
def test_mxfp4_scale_roundtrip(shape, mx_axis, num_warps):
    """Check that HopperMXScaleLayout swizzle followed by unswizzle recovers the input.

    A random uint8 CUDA tensor is swizzled and unswizzled; the leading
    `shape[0] x shape[1]` window of the result must equal the input.
    """
    n_rows, n_cols = shape
    scales = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
    scale_layout = HopperMXScaleLayout(scales.shape, mx_axis=mx_axis, num_warps=num_warps)
    swizzled = scale_layout.swizzle_data(scales)
    roundtripped = scale_layout.unswizzle_data(swizzled)
    # Only the window matching the input shape is compared
    # (presumably the unswizzled result may be padded -- confirm against the layout).
    assert (roundtripped[:n_rows, :n_cols] == scales).all()
| 39 | + |
| 40 | + |
| 41 | +# ------------------------------------------------------------ |
| 42 | +# Triton tests |
| 43 | +# ------------------------------------------------------------ |
| 44 | + |
| 45 | +# ------------------ upcast mxfp4 to bf16 -------------------- |
| 46 | + |
| 47 | + |
@triton.jit
def _upcast_mxfp4_to_bf16(Y, X, XScale, x_stride_m, x_stride_n, x_scale_stride_m, x_scale_stride_n, y_stride_m,
                          y_stride_n, X_BLOCK_M: tl.constexpr, X_BLOCK_N: tl.constexpr, Y_BLOCK_M: tl.constexpr,
                          Y_BLOCK_N: tl.constexpr, SCALE_BLOCK_M: tl.constexpr, SCALE_BLOCK_N: tl.constexpr,
                          mx_axis: tl.constexpr):
    """Load one value tile and one scale tile, upcast them, and store the result.

    Reads an X_BLOCK_M x X_BLOCK_N tile from X and a
    SCALE_BLOCK_M x SCALE_BLOCK_N tile from XScale, unswizzles the scales
    with `unswizzle_mxfp4_scale_hopper`, converts with
    `mxfp4_to_bf16_triton`, and writes a Y_BLOCK_M x Y_BLOCK_N tile to Y.
    No program id is read and no load/store is masked, so each block size
    must cover exactly the data to be accessed.
    """
    # value tile
    x_rows = tl.arange(0, X_BLOCK_M)
    x_cols = tl.arange(0, X_BLOCK_N)
    x_offsets = x_rows[:, None] * x_stride_m + x_cols[None, :] * x_stride_n
    values = tl.load(X + x_offsets)
    # scale tile
    scale_rows = tl.arange(0, SCALE_BLOCK_M)
    scale_cols = tl.arange(0, SCALE_BLOCK_N)
    scale_offsets = scale_rows[:, None] * x_scale_stride_m + scale_cols[None, :] * x_scale_stride_n
    swizzled_scale = tl.load(XScale + scale_offsets)
    # undo the scale swizzle using the warp count the kernel was launched with
    scale = unswizzle_mxfp4_scale_hopper(swizzled_scale, mx_axis=mx_axis, num_warps=tl.extra.cuda.num_warps())
    upcast = mxfp4_to_bf16_triton(values, scale, mx_axis=mx_axis)
    # output tile
    y_rows = tl.arange(0, Y_BLOCK_M)
    y_cols = tl.arange(0, Y_BLOCK_N)
    y_offsets = y_rows[:, None] * y_stride_m + y_cols[None, :] * y_stride_n
    tl.store(Y + y_offsets, upcast)
| 70 | + |
| 71 | + |
def test_upcast_mxfp4_to_bf16():
    """Compare the Triton upcast kernel against `upcast_from_mxfp`.

    A seeded random bf16 tensor is downcast with `downcast_to_mxfp`; the
    reference output is the `upcast_from_mxfp` of that result. The same
    values and scales are converted to HopperMXValueLayout and
    HopperMXScaleLayout, run through `_upcast_mxfp4_to_bf16` in a single
    program, and the kernel output must equal the reference exactly.
    """
    mx_axis = 0
    num_warps = 4
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    shape = (256, 128)
    source = torch.randn(shape, dtype=torch.bfloat16, device="cuda")

    # Reference path.
    raw_vals, raw_scales = downcast_to_mxfp(source, torch.uint8, axis=mx_axis)
    expected = upcast_from_mxfp(raw_vals, raw_scales, source.dtype, axis=mx_axis)

    # Kernel path: wrap the tensors and convert them to the Hopper layouts.
    wrapped_vals = wrap_torch_tensor(raw_vals, dtype=FP4)
    wrapped_scales = wrap_torch_tensor(raw_scales)
    hopper_vals = convert_layout(wrapped_vals, HopperMXValueLayout, mx_axis=mx_axis)
    hopper_scales = convert_layout(wrapped_scales, HopperMXScaleLayout, mx_axis=mx_axis, num_warps=num_warps)
    val_data = hopper_vals.storage.data
    scale_data = hopper_scales.storage.data

    out = torch.empty_like(expected)
    out_rows, out_cols = shape
    # Block sizes are the full storage shapes, so one program covers everything.
    _upcast_mxfp4_to_bf16[(1, )](
        out, val_data, scale_data,  #
        val_data.stride(0), val_data.stride(1),  #
        scale_data.stride(0), scale_data.stride(1),  #
        out.stride(0), out.stride(1),  #
        val_data.shape[0], val_data.shape[1],  #
        out_rows, out_cols,  #
        scale_data.shape[0], scale_data.shape[1],  #
        mx_axis=mx_axis, num_warps=num_warps)
    assert (out == expected).all()
0 commit comments