
Commit f7bd001

eqydsashidh authored and committed
[FP8][cuBLAS][H100] only test fp32 outputs for rowwise _scaled_mm on H100 (pytorch#162022)
Only cuBLAS supports float32 output, and cuBLAS only supports rowwise scaling on SM 9.0. Intended to land after pytorch#161305.

Pull Request resolved: pytorch#162022
Approved by: https://github.com/ngimel
1 parent 5c13dc4 commit f7bd001
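
To illustrate the constraint the commit message describes, here is a minimal sketch (an illustration under stated assumptions, not code from this commit) of how a test can decide which output dtypes to exercise for rowwise-scaled FP8 matmuls based on device capability:

    import torch

    # Constraint from the commit message: only cuBLAS supports a float32 output
    # for rowwise-scaled FP8 matmul, and cuBLAS only provides the rowwise recipe
    # on SM 9.0 (H100). So fp32 is only worth testing on SM 9.0 devices.
    # The list below is a hypothetical test parameter, not from the diff.
    output_dtypes = [torch.bfloat16]
    if torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0):
        output_dtypes.append(torch.float32)
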

File tree

1 file changed: +25 -13 lines


test/test_matmul_cuda.py

Lines changed: 25 additions & 13 deletions
@@ -1548,22 +1548,34 @@ def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
         x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
         y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)
 
-        # Calculate actual F8 mm
-        out_scaled_mm = mm_float8(
-            x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
-        )
+        def test():
+            # Calculate actual F8 mm
+            out_scaled_mm = mm_float8(
+                x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
+            )
 
-        # Calculate emulated F8 mm
-        out_emulated = mm_float8_emulated(
-            x_fp8, x_scales, y_fp8, y_scales, output_dtype
-        )
+            # Calculate emulated F8 mm
+            out_emulated = mm_float8_emulated(
+                x_fp8, x_scales, y_fp8, y_scales, output_dtype
+            )
 
-        if base_dtype in {torch.bfloat16, torch.float16}:
-            atol, rtol = 7e-2, 7e-2
-        else:
-            atol, rtol = 2e-3, 2e-3
+            if base_dtype in {torch.bfloat16, torch.float16}:
+                atol, rtol = 7e-2, 7e-2
+            else:
+                atol, rtol = 2e-3, 2e-3
 
-        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
+            self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
+
+        # only cuBLAS supports rowwise with fp32 output and cuBLAS only supports
+        # rowwise on SM 9.0
+        if torch.cuda.get_device_capability() != (9, 0) and output_dtype == torch.float:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Only bf16 high precision output types are supported for row-wise scaling."
+            ):
+                test()
+        else:
+            test()
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
