
Commit 7ca9a3b

eqymansiag05 authored and committed
[FP8][cuBLAS][H100] only test fp32 outputs for rowwise _scaled_mm on H100 (pytorch#162022)
Only cuBLAS supports float32 output, and cuBLAS only supports rowwise for SM 9.0.

Intended to land after pytorch#161305

Pull Request resolved: pytorch#162022
Approved by: https://github.com/ngimel
1 parent 6c28b4b commit 7ca9a3b
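For orientation, a minimal sketch of the operation this test exercises: a row-wise scaled FP8 matmul requesting a float32 output, which cuBLAS only accepts on SM 9.0 (H100). This assumes the torch._scaled_mm entry point with scale_a/scale_b/out_dtype keyword arguments; the shapes and tensor names are illustrative and are not taken from the test itself, which goes through its own mm_float8 helpers.

# Row-wise scaled FP8 matmul sketch (assumed torch._scaled_mm API, illustrative shapes).
import torch

M, K, N = 64, 128, 32
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# The second operand must be column-major for cuBLAS; a transposed view gives that layout.
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.rand(M, 1, device="cuda", dtype=torch.float32)  # one scale per row of a
scale_b = torch.rand(1, N, device="cuda", dtype=torch.float32)  # one scale per column of b

# On H100 (SM 9.0) this returns an fp32 result; on other architectures the fp32
# request for row-wise scaling is expected to raise a RuntimeError instead.
out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float32)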

File tree

1 file changed (+25 −13 lines)


test/test_matmul_cuda.py

Lines changed: 25 additions & 13 deletions
@@ -1530,22 +1530,34 @@ def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
         x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
         y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)
 
-        # Calculate actual F8 mm
-        out_scaled_mm = mm_float8(
-            x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
-        )
+        def test():
+            # Calculate actual F8 mm
+            out_scaled_mm = mm_float8(
+                x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
+            )
 
-        # Calculate emulated F8 mm
-        out_emulated = mm_float8_emulated(
-            x_fp8, x_scales, y_fp8, y_scales, output_dtype
-        )
+            # Calculate emulated F8 mm
+            out_emulated = mm_float8_emulated(
+                x_fp8, x_scales, y_fp8, y_scales, output_dtype
+            )
 
-        if base_dtype in {torch.bfloat16, torch.float16}:
-            atol, rtol = 7e-2, 7e-2
-        else:
-            atol, rtol = 2e-3, 2e-3
+            if base_dtype in {torch.bfloat16, torch.float16}:
+                atol, rtol = 7e-2, 7e-2
+            else:
+                atol, rtol = 2e-3, 2e-3
 
-        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
+            self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
+
+        # only cuBLAS supports rowwise with fp32 output and cuBLAS only supports
+        # rowwise on SM 9.0
+        if torch.cuda.get_device_capability() != (9, 0) and output_dtype == torch.float:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Only bf16 high precision output types are supported for row-wise scaling."
+            ):
+                test()
+        else:
+            test()
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
