Commit 086e2c2

Aidyn-A authored and pytorchmergebot committed
[TEST][ATen][CUDA] Skip row-wise scaled matrix multiplication tests on sm_120+ (pytorch#152814)
The float8 row-wise scaled matmuls are not supported on Blackwell yet. This PR adds skips to those tests to decrease the noise on `sm_120+` machines.

Pull Request resolved: pytorch#152814
Approved by: https://github.com/eqy, https://github.com/Skylion007
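Mechanically, the skip is the usual compute-capability gate: `torch.cuda.get_device_capability()` is compared against an `(major, minor)` tuple, and on matching devices the test is wrapped in `unittest.expectedFailure`. A minimal standalone sketch of that pattern (the snake_case names and the demo test are illustrative only; the actual helpers added by this commit appear in the diffs below):

import unittest

import torch

# sm_120 corresponds to compute capability (12, 0); tuple comparison means
# (12, 1) and any later architecture match as well.
IS_SM120_OR_LATER = (
    torch.cuda.is_available()
    and torch.cuda.get_device_capability() >= (12, 0)
)

def xfail_if_sm120_or_later(func):
    # Mark the test as an expected failure on sm_120+; leave it untouched elsewhere.
    return unittest.expectedFailure(func) if IS_SM120_OR_LATER else func

class RowwiseScalingDemo(unittest.TestCase):
    @xfail_if_sm120_or_later
    def test_rowwise_scaled_matmul(self):
        # A real test would call torch._scaled_mm with row-wise scales here.
        self.assertTrue(True)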
1 parent 4b8b7c7 · commit 086e2c2

2 files changed: +9 −2 lines changed

test/test_matmul_cuda.py

Lines changed: 5 additions & 2 deletions

@@ -24,6 +24,7 @@
     SM89OrLater,
     SM90OrLater,
     xfailIfSM100OrLater,
+    xfailIfSM120OrLater,
     _get_torch_cuda_version,
     PLATFORM_SUPPORTS_FP8,
     PLATFORM_SUPPORTS_MX_GEMM,
@@ -1012,8 +1013,9 @@ def test_float8_scale_fast_accum(self, device) -> None:
         out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
         self.assertEqual(out_fp8, out_fp8_s)
 
+    @xfailIfSM120OrLater
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
+    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
     @parametrize("use_fast_accum", [True, False])
     def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
         M, K, N = (1024, 512, 2048)
@@ -1117,8 +1119,9 @@ def test_float8_error_messages(self, device) -> None:
             out_dtype=torch.bfloat16,
         )
 
+    @xfailIfSM120OrLater
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
+    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
     @parametrize("base_dtype", [torch.bfloat16])
     def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
         torch.manual_seed(42)

torch/testing/_internal/common_cuda.py

Lines changed: 4 additions & 0 deletions

@@ -33,6 +33,7 @@
 SM89OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9))
 SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
 SM100OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0))
+SM120OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (12, 0))
 
 IS_THOR = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10
                   and torch.cuda.get_device_capability()[1] > 0)
@@ -335,6 +336,9 @@ def xfailIfSM89(func):
 def xfailIfSM100OrLater(func):
     return func if not SM100OrLater else unittest.expectedFailure(func)
 
+def xfailIfSM120OrLater(func):
+    return func if not SM120OrLater else unittest.expectedFailure(func)
+
 def xfailIfDistributedNotSupported(func):
     return func if not (IS_MACOS or IS_JETSON) else unittest.expectedFailure(func)
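For reference, the new gate can be checked on a given machine with a quick snippet (hypothetical usage, not part of the commit; `LazyVal` defers the capability query until the value is first used):

import torch
from torch.testing._internal.common_cuda import SM120OrLater

# Truthy on Blackwell sm_120+ devices, falsy on older GPUs and CPU-only builds;
# the underlying check is torch.cuda.get_device_capability() >= (12, 0).
print(bool(SM120OrLater))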
