@@ -1548,22 +1548,34 @@ def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
         x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
         y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)
 
-        # Calculate actual F8 mm
-        out_scaled_mm = mm_float8(
-            x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
-        )
+        def test():
+            # Calculate actual F8 mm
+            out_scaled_mm = mm_float8(
+                x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
+            )
 
-        # Calculate emulated F8 mm
-        out_emulated = mm_float8_emulated(
-            x_fp8, x_scales, y_fp8, y_scales, output_dtype
-        )
+            # Calculate emulated F8 mm
+            out_emulated = mm_float8_emulated(
+                x_fp8, x_scales, y_fp8, y_scales, output_dtype
+            )
 
-        if base_dtype in {torch.bfloat16, torch.float16}:
-            atol, rtol = 7e-2, 7e-2
-        else:
-            atol, rtol = 2e-3, 2e-3
+            if base_dtype in {torch.bfloat16, torch.float16}:
+                atol, rtol = 7e-2, 7e-2
+            else:
+                atol, rtol = 2e-3, 2e-3
 
-        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
+            self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
+
+        # only cuBLAS supports rowwise with fp32 output and cuBLAS only supports
+        # rowwise on SM 9.0
+        if torch.cuda.get_device_capability() != (9, 0) and output_dtype == torch.float:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Only bf16 high precision output types are supported for row-wise scaling."
+            ):
+                test()
+        else:
+            test()
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")