@@ -882,6 +882,8 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
 # largest power of 2 representable in `torch.float8_e4m3fn`
 F8E4M3_LARGEST_POW2 = 8
+# largest power of 2 representable in `torch.float4_e2m1fn_x2`
+FP4E2M1FN_LARGEST_POW2 = 1.0
 # max value of `torch.float8_e4m3fn` (448)
 F8E4M3_MAX_VAL = torch.finfo(torch.float8_e4m3fn).max
 # exponent bias of `torch.float8_e8m0fnu`
@@ -890,14 +892,20 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 FP4_EBITS, FP4_MBITS = 2, 1
 FP4_MAX_VAL = 6.0
 
-def data_to_mx_scale(x, block_size):
+def data_to_mx_scale(x, block_size, recipe):
     # simple implementation of https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
     # section 6.3, not all edge cases (such as NaN) are handled/tested
+    if recipe == "mxfp8":
+        largest_pow2 = F8E4M3_LARGEST_POW2
+    elif recipe == "mxfp4":
+        largest_pow2 = FP4E2M1FN_LARGEST_POW2
+    else:
+        raise ValueError(f"data_to_mx_scale(): Unsupported mx recipe: {recipe}")
     orig_shape = x.shape
     x = x.reshape(-1, block_size)
     max_abs = torch.amax(torch.abs(x), 1)
     largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs))
-    scale_e8m0_unbiased = largest_p2_lt_max_abs - F8E4M3_LARGEST_POW2
+    scale_e8m0_unbiased = largest_p2_lt_max_abs - largest_pow2
     scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, -1 * F8E8M0_EXP_BIAS, F8E8M0_EXP_BIAS)
     scale_e8m0_biased = scale_e8m0_unbiased + F8E8M0_EXP_BIAS
     scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)
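
As a quick sanity check on the exponent arithmetic above, here is a standalone worked example of the same e8m0 scale computation; the block contents (96.0) are made up for illustration, and F8E8M0_EXP_BIAS = 127 restates the constant defined just above this hunk.

    import torch

    F8E4M3_LARGEST_POW2 = 8   # 2**8 = 256, the largest power of two <= 448
    F8E8M0_EXP_BIAS = 127     # exponent bias of torch.float8_e8m0fnu

    # One 32-element block whose max-abs value is 96.0.
    block = torch.full((1, 32), 96.0)
    max_abs = torch.amax(torch.abs(block), 1)        # tensor([96.])
    largest_p2 = torch.floor(torch.log2(max_abs))    # tensor([6.])
    unbiased = largest_p2 - F8E4M3_LARGEST_POW2      # tensor([-2.])
    biased = torch.clamp(unbiased, -F8E8M0_EXP_BIAS, F8E8M0_EXP_BIAS) + F8E8M0_EXP_BIAS
    print(biased.to(torch.uint8))                    # tensor([125], dtype=torch.uint8)
    # Dividing the block by the scale 2**-2 = 0.25 gives 384, which stays inside
    # the float8_e4m3fn range (max 448) before the cast to fp8.

The mxfp4 branch swaps in FP4E2M1FN_LARGEST_POW2 so the scaled data lands in the e2m1 range instead.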
@@ -1446,20 +1454,21 @@ def test_pack_uint4(self):
         (127, 96, 1024),
         (1025, 128, 96)
     ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
-    @parametrize("recipe", ["mxfp8", "nvfp4"])
-    def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
-        if recipe == "nvfp4" and fast_accum:
-            return unittest.skip("fast_accum not supported in nvfp4 cublas gemm, skipping")
+    @parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
+    def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
+        if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
+            raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
 
         device = "cuda"
         M, K, N = mkn
         if torch.version.hip:
             if not (M % 32 == 0 and K % 32 == 0 and N % 32 == 0):
                 raise unittest.SkipTest("Matrix dimensions must be multiples of 32 on ROCm, skipping")
 
-        if recipe == "nvfp4" and K % 32 != 0:
-            return unittest.skip("K must be divisible by 32 for nvfp4 cublas gemm, skipping")
+        if (recipe == "nvfp4" or recipe == "mxfp4") and K % 32 != 0:
+            raise unittest.SkipTest("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping")
 
+        fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
         BLOCK_SIZE = 16 if recipe == "nvfp4" else 32
         require_exact_match = True
         approx_match_sqnr_target = 22.0
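
To summarize the recipe differences that the parameters above encode: mxfp8 and mxfp4 use 32-element blocks, nvfp4 uses 16-element blocks, and on the ROCm mxfp4 path the fp4 scales are float8_e8m0fnu while the CUDA nvfp4 path keeps float8_e4m3fn scales. A tiny illustrative helper capturing that mapping (the function name is invented for this note, not part of the test):

    import torch

    def blockwise_params(recipe):
        # mxfp8 / mxfp4 follow the OCP MX layout: 32-element blocks, e8m0 scales.
        # nvfp4 uses 16-element blocks with float8_e4m3fn scales.
        if recipe in ("mxfp8", "mxfp4"):
            return 32, torch.float8_e8m0fnu
        if recipe == "nvfp4":
            return 16, torch.float8_e4m3fn
        raise ValueError(f"unsupported recipe: {recipe}")

    print(blockwise_params("mxfp4"))   # (32, torch.float8_e8m0fnu)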
@@ -1475,11 +1484,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B_ref.to(torch.float8_e4m3fn)
                 A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
                 B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            else:  # nvfp4
+            else:  # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
 
         elif test_case_name == "a_ones_b_ones":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1490,11 +1499,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B_ref.to(torch.float8_e4m3fn)
                 A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
                 B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            else:  # nvfp4
+            else:  # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
 
         elif test_case_name == "a_ones_modified_b_ones":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1506,11 +1515,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B_ref.to(torch.float8_e4m3fn)
                 A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
                 B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            else:  # nvfp4
+            else:  # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
 
         elif test_case_name == "a_ones_b_ones_modified":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1522,11 +1531,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B_ref.to(torch.float8_e4m3fn)
                 A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
                 B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            else:  # nvfp4
+            else:  # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
 
         elif test_case_name == "a_scale_modified_b_ones":
             A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1540,11 +1549,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 A_ref[1][0:BLOCK_SIZE] = 4
                 A[1][0:BLOCK_SIZE] = 2
                 A_scale[1][0] = 2
-            else:  # nvfp4
+            else:  # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
                 A_ref[1][0:BLOCK_SIZE] = 4
                 A.view(torch.uint8)[1][0:(BLOCK_SIZE // 2)] = 0b01000100
                 A_scale[1][0] = 2
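
The raw byte written into A just above packs two e2m1 values; for context, a minimal pure-Python decoder (illustrative only, and nibble order does not matter here since both halves of 0b01000100 are identical). Each nibble 0b0100 decodes to 2.0, so with the block scale of 2 every element dequantizes back to 4.0, matching A_ref.

    def decode_e2m1(nibble):
        # e2m1: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit; max value 6.0.
        sign = -1.0 if (nibble >> 3) & 1 else 1.0
        exp = (nibble >> 1) & 0b11
        mant = nibble & 1
        if exp == 0:                      # subnormals: 0.0 and 0.5
            return sign * 0.5 * mant
        return sign * (2.0 ** (exp - 1)) * (1.0 + 0.5 * mant)

    byte = 0b01000100
    print(decode_e2m1(byte & 0xF), decode_e2m1(byte >> 4))   # 2.0 2.0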
@@ -1561,11 +1570,11 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B_ref[1][0:BLOCK_SIZE] = 4
                 B[1][0:BLOCK_SIZE] = 2
                 B_scale[1][0] = 2
-            else:  # nvfp4
+            else:  # nvfp4 # mxfp4
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
                 B_ref[1][0:BLOCK_SIZE] = 4
                 B.view(torch.uint8)[1][0:(BLOCK_SIZE // 2)] = 0b01000100
                 B_scale[1][0] = 2
@@ -1585,7 +1594,7 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B_ref.to(torch.float8_e4m3fn)
                 A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
                 B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-            else:  # nvfp4
+            else:  # nvfp4 # mxfp4
                 # scales all-ones, element data random while being exactly representable in float4_e2m1fn_x2
                 # generate integers in [0, 16] and cast to bfloat16
                 A_ref = _floatx_unpacked_to_f32(
@@ -1600,8 +1609,8 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 ).bfloat16()
                 A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
                 B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
-                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
-                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
+                A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
+                B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
 
         elif test_case_name == "data_random_scales_from_data":
             if not K % BLOCK_SIZE == 0:
@@ -1613,17 +1622,18 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
 
             if recipe == "mxfp8":
                 # Calculate scales based on the inputs
-                A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE)
-                B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE)
+                A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE, recipe)
+                B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE, recipe)
                 max_val = F8E4M3_MAX_VAL
                 min_val = -1 * max_val
                 A = (A_ref.reshape(-1, BLOCK_SIZE) / A_scale.reshape(M * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(M, K)
                 A = A.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
                 B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K)
                 B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
-            else:  # nvfp4
-                A_scale = data_to_nvfp4_scale(A_ref, BLOCK_SIZE)
-                B_scale = data_to_nvfp4_scale(B_ref, BLOCK_SIZE)
+            else:  # nvfp4 # mxfp4
+                scale_func = data_to_mx_scale if recipe == "mxfp4" else data_to_nvfp4_scale
+                A_scale = scale_func(A_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
+                B_scale = scale_func(B_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
                 max_val = FP4_MAX_VAL
                 min_val = -1 * max_val
 
@@ -1634,13 +1644,14 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
                 B = B.clamp(min=min_val, max=max_val)
                 B = _bfloat16_to_float4_e2m1fn_x2(B)
 
-            approx_match_sqnr_target = 15.8
+            approx_match_sqnr_target = 12.0 if torch.version.hip else 15.8
 
         C_ref = A_ref @ B_ref.t()
 
         # convert to swizzled format
-        A_scale = to_blocked(A_scale)
-        B_scale = to_blocked(B_scale)
+        if not torch.version.hip:
+            A_scale = to_blocked(A_scale)
+            B_scale = to_blocked(B_scale)
 
         C = torch._scaled_mm(
             A,
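
For reference, a condensed sketch of what the mxfp8 path of this test exercises end to end on CUDA hardware where PLATFORM_SUPPORTS_MX_GEMM holds. It reuses the helpers defined earlier in this file (data_to_mx_scale, to_blocked, compute_error, F8E4M3_MAX_VAL, ceil_div); the full torch._scaled_mm argument list is an assumption here, since the hunk above only shows the start of the call, and the sizes and random inputs are illustrative.

    import torch

    M = K = N = 128
    BLOCK_SIZE = 32
    A_ref = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) * 1000
    B_ref = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) * 1000

    # Per-block e8m0 scales derived from the data, then quantize to fp8.
    A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE, "mxfp8")
    B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE, "mxfp8")
    A = (A_ref.reshape(-1, BLOCK_SIZE) / A_scale.reshape(-1, 1).float()).reshape(M, K)
    A = A.clamp(-F8E4M3_MAX_VAL, F8E4M3_MAX_VAL).to(torch.float8_e4m3fn)
    B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(-1, 1).float()).reshape(N, K)
    B = B.clamp(-F8E4M3_MAX_VAL, F8E4M3_MAX_VAL).to(torch.float8_e4m3fn)

    # Swizzle the scales (CUDA path only) and run the block-scaled gemm.
    C = torch._scaled_mm(A, B.t(), to_blocked(A_scale), to_blocked(B_scale),
                         out_dtype=torch.bfloat16)
    sqnr = compute_error(A_ref @ B_ref.t(), C)   # expected to clear the sqnr target above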
@@ -1657,6 +1668,7 @@ def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, r
         sqnr = compute_error(C_ref, C)
         assert sqnr.item() > approx_match_sqnr_target
 
+    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM or IS_WINDOWS, mx_skip_msg)
     @parametrize("recipe", ["mxfp8", "nvfp4"])
     def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None:
@@ -1899,6 +1911,7 @@ def test_blockwise_mxfp8_compile(self) -> None:
         )
         torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
 
+    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
     def test_blockwise_nvfp4_compile(self) -> None:
 