Skip to content

Commit ce7fd34

Browse files
authored
[rocm7.0_internal_testing]fp8: optimize skip rowwise tests (#2476)
Skip based on ROCm version and gfx type. Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
1 parent a928221 commit ce7fd34

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

test/test_matmul_cuda.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
parametrize,
4747
run_tests,
4848
skipIfRocm,
49-
skipIfRocmArch,
49+
skipIfRocmVersionAndArch,
5050
skipIfRocmVersionLessThan,
5151
TEST_CUDA,
5252
TEST_WITH_ROCM,
@@ -908,7 +908,7 @@ def test_float8_scale_fast_accum(self, device) -> None:
908908
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
909909
self.assertEqual(out_fp8, out_fp8_s)
910910

911-
@skipIfRocmArch("gfx950")
911+
@skipIfRocmVersionAndArch((7, 1), "gfx950")
912912
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
913913
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
914914
@parametrize("use_fast_accum", [True, False])
@@ -1014,7 +1014,7 @@ def test_float8_error_messages(self, device) -> None:
10141014
out_dtype=torch.bfloat16,
10151015
)
10161016

1017-
@skipIfRocmArch("gfx950")
1017+
@skipIfRocmVersionAndArch((7, 1), "gfx950")
10181018
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
10191019
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
10201020
@parametrize("base_dtype", [torch.bfloat16])

torch/testing/_internal/common_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1987,6 +1987,23 @@ def wrap_fn(self, *args, **kwargs):
19871987
return wrap_fn
19881988
return dec_fn
19891989

1990+
def skipIfRocmVersionAndArch(version=None, arch=None):
1991+
def dec_fn(fn):
1992+
@wraps(fn)
1993+
def wrap_fn(self, *args, **kwargs):
1994+
if TEST_WITH_ROCM:
1995+
rocm_version = str(torch.version.hip)
1996+
rocm_version = rocm_version.split("-")[0] # ignore git sha
1997+
rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
1998+
if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
1999+
prop = torch.cuda.get_device_properties(0)
2000+
if prop.gcnArchName.split(":")[0] in arch:
2001+
reason = f"ROCm {version} and {arch} combination not supported"
2002+
raise unittest.SkipTest(reason)
2003+
return fn(self, *args, **kwargs)
2004+
return wrap_fn
2005+
return dec_fn
2006+
19902007
def skipIfNotMiopenSuggestNHWC(fn):
19912008
@wraps(fn)
19922009
def wrapper(*args, **kwargs):

0 commit comments

Comments
 (0)