@@ -1530,22 +1530,34 @@ def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
         x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
         y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)
 
-        # Calculate actual F8 mm
-        out_scaled_mm = mm_float8(
-            x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
-        )
+        def test():
+            # Calculate actual F8 mm
+            out_scaled_mm = mm_float8(
+                x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
+            )
 
-        # Calculate emulated F8 mm
-        out_emulated = mm_float8_emulated(
-            x_fp8, x_scales, y_fp8, y_scales, output_dtype
-        )
+            # Calculate emulated F8 mm
+            out_emulated = mm_float8_emulated(
+                x_fp8, x_scales, y_fp8, y_scales, output_dtype
+            )
 
-        if base_dtype in {torch.bfloat16, torch.float16}:
-            atol, rtol = 7e-2, 7e-2
-        else:
-            atol, rtol = 2e-3, 2e-3
+            if base_dtype in {torch.bfloat16, torch.float16}:
+                atol, rtol = 7e-2, 7e-2
+            else:
+                atol, rtol = 2e-3, 2e-3
 
-        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
+            self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
+
+        # Only cuBLAS supports row-wise scaling with fp32 output, and cuBLAS
+        # only supports row-wise scaling on SM 9.0.
+        if torch.cuda.get_device_capability() != (9, 0) and output_dtype == torch.float:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Only bf16 high precision output types are supported for row-wise scaling."
+            ):
+                test()
+        else:
+            test()
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
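The key change in this diff is structural: the comparison logic moves into a local test() closure so the same code path can be run either directly or under self.assertRaisesRegex, depending on whether the device/dtype combination is supported. A minimal self-contained sketch of that pattern follows; scaled_mm_stub, SUPPORTS_FP32_OUTPUT, and the class name are hypothetical stand-ins for illustration, not PyTorch APIs.

import unittest

# Hypothetical stand-in for the hardware condition checked in the diff
# (cuBLAS row-wise scaling with fp32 output, available only on SM 9.0).
SUPPORTS_FP32_OUTPUT = False


def scaled_mm_stub(output_dtype):
    # Stand-in for mm_float8: raises exactly when the backend would.
    if output_dtype == "float32" and not SUPPORTS_FP32_OUTPUT:
        raise RuntimeError(
            "Only bf16 high precision output types are supported for row-wise scaling."
        )
    return 0.0


class ConditionalExpectedErrorPattern(unittest.TestCase):
    def test_row_wise_output_dtype(self):
        output_dtype = "float32"

        def body():
            # The real test compares the scaled mm against an emulated
            # reference here; the stub stands in for that call.
            result = scaled_mm_stub(output_dtype)
            self.assertEqual(result, 0.0)

        # Same shape as the diff: run the body under an expected-error
        # context when the configuration is unsupported, directly otherwise.
        if output_dtype == "float32" and not SUPPORTS_FP32_OUTPUT:
            with self.assertRaisesRegex(
                RuntimeError, "Only bf16 high precision output types"
            ):
                body()
        else:
            body()


if __name__ == "__main__":
    unittest.main()

The closure keeps the success and failure paths byte-for-byte identical, so the test cannot drift into asserting different behavior on supported versus unsupported hardware.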