
Commit 2975ee1

[release/2.8] Add mx fp4 support (#2472)
MX FP8 is already enabled through a patch cherry-picked from release 2.7; this patch adds support for MX FP4.

Test: PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v
Ran 452 tests in 23.776s, OK (skipped=340), 112 passed.

---------
Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
Co-authored-by: Copilot <[email protected]>
1 parent e96dc85 commit 2975ee1

4 files changed (+71 lines, -38 lines)
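In user-facing terms, this change enables blockwise mxfp4 GEMMs through torch._scaled_mm on ROCm 7.0+. Below is a minimal sketch of the new path, mirroring the mxfp4 branch of the updated test further down; it is illustrative only and not runnable standalone, since _bfloat16_to_float4_e2m1fn_x2 is the packing helper defined in test/test_matmul_cuda.py and the shapes are arbitrary:

import torch

M, K, N = 128, 128, 128      # multiples of 32, as required on ROCm
BLOCK_SIZE = 32              # mxfp4/mxfp8 scaling-block size

A_ref = torch.ones(M, K, device="cuda", dtype=torch.bfloat16)
B_ref = torch.ones(N, K, device="cuda", dtype=torch.bfloat16)

# Pack bfloat16 into float4_e2m1fn_x2 (two FP4 values per byte) with the test-file helper.
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)

# On ROCm the mxfp4 recipe uses float8_e8m0fnu block scales, passed without the
# to_blocked() swizzle that the CUDA path applies (see the test diff below).
A_scale = torch.full((M, K // BLOCK_SIZE), 1.0, device="cuda", dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, K // BLOCK_SIZE), 1.0, device="cuda", dtype=torch.float8_e8m0fnu)

C = torch._scaled_mm(A, B.t(), A_scale, B_scale, out_dtype=torch.bfloat16)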

aten/src/ATen/cuda/CUDADataType.h

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
     case c10::ScalarType::Float8_e5m2fnuz:
       return HIP_R_8F_E5M2_FNUZ;
 #endif
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080) || (defined(USE_ROCM) && ROCM_VERSION >= 70000)
     case c10::ScalarType::Float4_e2m1fn_x2:
       return CUDA_R_4F_E2M1;
 #endif

aten/src/ATen/cuda/tunable/GemmHipblaslt.h

Lines changed: 9 additions & 0 deletions
@@ -85,6 +85,15 @@ constexpr hipDataType HipDataTypeFor<c10::Float8_e8m0fnu>() {
   return static_cast<hipDataType>(500);
 }
 
+template <>
+constexpr hipDataType HipDataTypeFor<c10::Float4_e2m1fn_x2>() {
+#if ROCM_VERSION >= 70000
+  return HIP_R_4F_E2M1;
+#else
+  return static_cast<hipDataType>(33);
+#endif
+}
+
 template <typename T>
 int GetBatchFromParams(const GemmParams<T>* params) {
   return 1;

aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 11 additions & 0 deletions
@@ -1284,6 +1284,17 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   if (use_fast_accum) {
     TORCH_CHECK(mat1.scalar_type() != ScalarType::Float4_e2m1fn_x2 && mat2.scalar_type() != ScalarType::Float4_e2m1fn_x2, "`use_fast_accum` is not supported when `mat1` or `mat2` tensors have the `Float4_e2m1fn_x2` dtype.");
   }
+#ifdef USE_ROCM
+  if (mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2 || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2) {
+    TORCH_CHECK(ROCM_VERSION >= 70000, "Float4_e2m1fn_x2 is only supported for ROCm 7.0 and above");
+  }
+  if (mat1.scalar_type() == ScalarType::Float8_e5m2 || mat2.scalar_type() == ScalarType::Float8_e5m2) {
+    TORCH_CHECK(ROCM_VERSION >= 70000, "Float8_e5m2 is only supported for ROCm 7.0 and above");
+  }
+  if (mat1.scalar_type() == ScalarType::Float8_e4m3fn || mat2.scalar_type() == ScalarType::Float8_e4m3fn) {
+    TORCH_CHECK(ROCM_VERSION >= 70000, "Float8_e4m3fn is only supported for ROCm 7.0 and above");
+  }
+#endif
   if (bias) {
     TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
     TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
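These guards reject FP4/FP8 scaled GEMMs on ROCm builds older than 7.0. As context, here is a minimal sketch (not part of this patch, helper name hypothetical) of how calling code could probe the same threshold before taking the mxfp4 path; the only assumptions are torch.version.hip, which is None on CUDA builds, and the 7.0 cutoff taken from the TORCH_CHECKs above:

import torch

def rocm_supports_mx_fp4() -> bool:
    # torch.version.hip is a version string such as "7.0...." on ROCm builds, None on CUDA builds.
    if torch.version.hip is None:
        return False
    major = int(torch.version.hip.split(".")[0])
    return major >= 7  # mirrors the ROCM_VERSION >= 70000 checks above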

test/test_matmul_cuda.py

Lines changed: 50 additions & 37 deletions
@@ -882,6 +882,8 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
 # largest power of 2 representable in `torch.float8_e4m3fn`
 F8E4M3_LARGEST_POW2 = 8
+# largest power of 2 representable in `torch.float4_e2m1fn_x2`
+FP4E2M1FN_LARGEST_POW2 = 1.0
 # max value of `torch.float8_e4m3fn` (448)
 F8E4M3_MAX_VAL = torch.finfo(torch.float8_e4m3fn).max
 # exponent bias of `torch.float8_e8m0fnu`
@@ -890,14 +892,20 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 FP4_EBITS, FP4_MBITS = 2, 1
 FP4_MAX_VAL = 6.0
 
-def data_to_mx_scale(x, block_size):
+def data_to_mx_scale(x, block_size, recipe):
     # simple implementation of https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
     # section 6.3, not all edge cases (such as NaN) are handled/tested
+    if recipe == "mxfp8":
+        largest_pow2 = F8E4M3_LARGEST_POW2
+    elif recipe == "mxfp4":
+        largest_pow2 = FP4E2M1FN_LARGEST_POW2
+    else:
+        raise ValueError(f"data_to_mx_scale(): Unsupported mx recipe: {recipe}")
     orig_shape = x.shape
     x = x.reshape(-1, block_size)
     max_abs = torch.amax(torch.abs(x), 1)
     largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs))
-    scale_e8m0_unbiased = largest_p2_lt_max_abs - F8E4M3_LARGEST_POW2
+    scale_e8m0_unbiased = largest_p2_lt_max_abs - largest_pow2
     scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, -1 * F8E8M0_EXP_BIAS, F8E8M0_EXP_BIAS)
     scale_e8m0_biased = scale_e8m0_unbiased + F8E8M0_EXP_BIAS
     scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)
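As a quick sanity check of the scale computation above (illustration only, not part of the patch): for an mxfp4 block whose largest magnitude is 12.0, the unbiased e8m0 exponent is floor(log2(12)) - FP4E2M1FN_LARGEST_POW2 = 3 - 1 = 2, i.e. a block scale of 2^2 = 4, which maps 12.0 to 3.0, inside FP4's representable range (FP4_MAX_VAL = 6.0). The e8m0 exponent bias is 127, so the stored biased scale byte is 129:

import torch

block = torch.tensor([0.5, -12.0, 3.0, 1.25])             # one scaling block, max |x| = 12
largest_p2 = torch.floor(torch.log2(block.abs().amax()))  # 3.0
unbiased = largest_p2 - 1.0                               # minus FP4E2M1FN_LARGEST_POW2 -> 2.0
biased = (unbiased + 127).to(torch.uint8)                 # plus the e8m0 exponent bias (127) -> 129
scaled = block / (2.0 ** unbiased)                        # [0.125, -3.0, 0.75, 0.3125], all within +/-6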
@@ -1446,20 +1454,21 @@ def test_pack_uint4(self):
         (127, 96, 1024),
         (1025, 128, 96)
     ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
-    @parametrize("recipe", ["mxfp8", "nvfp4"])
-    def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
-        if recipe == "nvfp4" and fast_accum:
-            return unittest.skip("fast_accum not supported in nvfp4 cublas gemm, skipping")
+    @parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
+    def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
+        if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
+            raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
 
         device = "cuda"
         M, K, N = mkn
         if torch.version.hip:
             if not (M % 32 == 0 and K % 32 == 0 and N % 32 == 0):
                 raise unittest.SkipTest("Matrix dimensions must be multiples of 32 on ROCm, skipping")
 
-        if recipe == "nvfp4" and K % 32 != 0:
-            return unittest.skip("K must be divisible by 32 for nvfp4 cublas gemm, skipping")
+        if (recipe == "nvfp4" or recipe == "mxfp4") and K % 32 != 0:
+            raise unittest.SkipTest("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping")
 
+        fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
         BLOCK_SIZE = 16 if recipe == "nvfp4" else 32
         require_exact_match = True
         approx_match_sqnr_target = 22.0
@@ -1475,11 +1484,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B_ref.to(torch.float8_e4m3fn)
                 A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
                 B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            else: # nvfp4
+            else: # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
 
         elif test_case_name == "a_ones_b_ones":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1490,11 +1499,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B_ref.to(torch.float8_e4m3fn)
                 A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
                 B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            else: # nvfp4
+            else: # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
 
         elif test_case_name == "a_ones_modified_b_ones":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1506,11 +1515,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B_ref.to(torch.float8_e4m3fn)
                 A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
                 B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            else: # nvfp4
+            else: # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
 
         elif test_case_name == "a_ones_b_ones_modified":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1522,11 +1531,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B_ref.to(torch.float8_e4m3fn)
                 A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
                 B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            else: # nvfp4
+            else: # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
 
         elif test_case_name == "a_scale_modified_b_ones":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1540,11 +1549,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 A_ref[1][0:BLOCK_SIZE] = 4
                 A[1][0:BLOCK_SIZE] = 2
                 A_scale[1][0] = 2
-            else: # nvfp4
+            else: # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
                 A_ref[1][0:BLOCK_SIZE] = 4
                 A.view(torch.uint8)[1][0:(BLOCK_SIZE // 2)] = 0b01000100
                 A_scale[1][0] = 2
@@ -1561,11 +1570,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B_ref[1][0:BLOCK_SIZE] = 4
                 B[1][0:BLOCK_SIZE] = 2
                 B_scale[1][0] = 2
-            else: # nvfp4
+            else: # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
                 B_ref[1][0:BLOCK_SIZE] = 4
                 B.view(torch.uint8)[1][0:(BLOCK_SIZE // 2)] = 0b01000100
                 B_scale[1][0] = 2
@@ -1585,7 +1594,7 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B_ref.to(torch.float8_e4m3fn)
                 A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
                 B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            else: # nvfp4
+            else: # nvfp4 # mxfp4
                 # scales all-ones, element data random while being exactly representable in float4_e2m1fn_x2
                 # generate integers in [0, 16] and cast to bfloat16
                 A_ref = _floatx_unpacked_to_f32(
@@ -1600,8 +1609,8 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 ).bfloat16()
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
 
         elif test_case_name == "data_random_scales_from_data":
             if not K % BLOCK_SIZE == 0:
@@ -1613,17 +1622,18 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
 
             if recipe == "mxfp8":
                 # Calculate scales based on the inputs
-                A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE)
-                B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE)
+                A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE, recipe)
+                B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE, recipe)
                 max_val = F8E4M3_MAX_VAL
                 min_val = -1 * max_val
                 A = (A_ref.reshape(-1, BLOCK_SIZE) / A_scale.reshape(M * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(M, K)
                 A = A.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
                 B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K)
                 B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
-            else: # nvfp4
-                A_scale = data_to_nvfp4_scale(A_ref, BLOCK_SIZE)
-                B_scale = data_to_nvfp4_scale(B_ref, BLOCK_SIZE)
+            else: # nvfp4 # mxfp4
+                scale_func = data_to_mx_scale if recipe == "mxfp4" else data_to_nvfp4_scale
+                A_scale = scale_func(A_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
+                B_scale = scale_func(B_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
                 max_val = FP4_MAX_VAL
                 min_val = -1 * max_val
 
@@ -1634,13 +1644,14 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B.clamp(min=min_val, max=max_val)
                 B = _bfloat16_to_float4_e2m1fn_x2(B)
 
-                approx_match_sqnr_target = 15.8
+                approx_match_sqnr_target = 12.0 if torch.version.hip else 15.8
 
         C_ref = A_ref @ B_ref.t()
 
         # convert to swizzled format
-        A_scale = to_blocked(A_scale)
-        B_scale = to_blocked(B_scale)
+        if not torch.version.hip:
+            A_scale = to_blocked(A_scale)
+            B_scale = to_blocked(B_scale)
 
         C = torch._scaled_mm(
             A,
@@ -1657,6 +1668,7 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
         sqnr = compute_error(C_ref, C)
         assert sqnr.item() > approx_match_sqnr_target
 
+    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM or IS_WINDOWS, mx_skip_msg)
     @parametrize("recipe", ["mxfp8", "nvfp4"])
     def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None:
@@ -1899,6 +1911,7 @@ def test_blockwise_mxfp8_compile(self) -> None:
         )
         torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
 
+    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
     def test_blockwise_nvfp4_compile(self) -> None:
 
0 commit comments
