Skip to content

Commit d2d9708

Browse files
authored
[release/2.8] fp8: skip rowwise tests (#2477)
fp8 rowwise scaling is not supported on ROCm 7.0 w/ gfx950, works on mainline. Skip the test for now. Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
1 parent eb47158 commit d2d9708

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

test/test_matmul_cuda.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
parametrize,
4747
run_tests,
4848
skipIfRocm,
49+
skipIfRocmVersionAndArch,
4950
skipIfRocmVersionLessThan,
5051
TEST_CUDA,
5152
TEST_WITH_ROCM,
@@ -1197,6 +1198,7 @@ def test_float8_scale_fast_accum(self, device) -> None:
11971198
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
11981199
self.assertEqual(out_fp8, out_fp8_s)
11991200

1201+
@skipIfRocmVersionAndArch((7, 1), "gfx950")
12001202
@onlyCUDA
12011203
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
12021204
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
@@ -1304,6 +1306,7 @@ def test_float8_error_messages(self, device) -> None:
13041306
out_dtype=torch.bfloat16,
13051307
)
13061308

1309+
@skipIfRocmVersionAndArch((7, 1), "gfx950")
13071310
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
13081311
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
13091312
@parametrize("base_dtype", [torch.bfloat16])

torch/testing/_internal/common_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2024,6 +2024,23 @@ def wrap_fn(self, *args, **kwargs):
20242024
return wrap_fn
20252025
return dec_fn
20262026

2027+
def skipIfRocmVersionAndArch(version=None, arch=None):
2028+
def dec_fn(fn):
2029+
@wraps(fn)
2030+
def wrap_fn(self, *args, **kwargs):
2031+
if TEST_WITH_ROCM:
2032+
rocm_version = str(torch.version.hip)
2033+
rocm_version = rocm_version.split("-")[0] # ignore git sha
2034+
rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
2035+
if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
2036+
prop = torch.cuda.get_device_properties(0)
2037+
if prop.gcnArchName.split(":")[0] in arch:
2038+
reason = f"ROCm {version} and {arch} combination not supported"
2039+
raise unittest.SkipTest(reason)
2040+
return fn(self, *args, **kwargs)
2041+
return wrap_fn
2042+
return dec_fn
2043+
20272044
def skipIfNotMiopenSuggestNHWC(fn):
20282045
@wraps(fn)
20292046
def wrapper(*args, **kwargs):

0 commit comments

Comments
 (0)