Skip to content

Commit e7df144

Browse files
authored
[release/2.9] add AMD routine to common_utils.py of test framework (#2778)
added missed test routine to release/2.9 which present in release/2.8 and main branch of upstream. Fixes #SWDEV-555401 Signed-off-by: Artem Kuzmitckii <[email protected]>
1 parent a82efe1 commit e7df144

File tree

2 files changed

+33
-16
lines changed

2 files changed

+33
-16
lines changed

functorch/experimental/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# PyTorch forward-mode is not mature yet
2-
from functorch import functionalize
32
from torch._functorch.apis import chunk_vmap
43
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
54
from torch._functorch.eager_transforms import hessian, jacfwd, jvp
5+
from torch.func import functionalize

torch/testing/_internal/common_utils.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,20 @@
101101
except ImportError:
102102
has_pytest = False
103103

104-
104+
MI350_ARCH = ("gfx950",)
105105
MI300_ARCH = ("gfx942",)
106-
106+
MI200_ARCH = ("gfx90a")
107+
NAVI_ARCH = ("gfx1030", "gfx1100", "gfx1101", "gfx1200", "gfx1201")
108+
NAVI3_ARCH = ("gfx1100", "gfx1101")
109+
NAVI4_ARCH = ("gfx1200", "gfx1201")
110+
111+
def is_navi3_arch():
112+
if torch.cuda.is_available():
113+
prop = torch.cuda.get_device_properties(0)
114+
gfx_arch = prop.gcnArchName.split(":")[0]
115+
if gfx_arch in NAVI3_ARCH:
116+
return True
117+
return False
107118

108119
def freeze_rng_state(*args, **kwargs):
109120
return torch.testing._utils.freeze_rng_state(*args, **kwargs)
@@ -1920,15 +1931,20 @@ def wrapper(*args, **kwargs):
19201931
return dec_fn(func)
19211932
return dec_fn
19221933

1934+
def getRocmArchName(device_index: int = 0):
1935+
return torch.cuda.get_device_properties(device_index).gcnArchName
1936+
1937+
def isRocmArchAnyOf(arch: tuple[str, ...]):
1938+
rocmArch = getRocmArchName()
1939+
return any(x in rocmArch for x in arch)
1940+
19231941
def skipIfRocmArch(arch: tuple[str, ...]):
19241942
def dec_fn(fn):
19251943
@wraps(fn)
19261944
def wrap_fn(self, *args, **kwargs):
1927-
if TEST_WITH_ROCM:
1928-
prop = torch.cuda.get_device_properties(0)
1929-
if prop.gcnArchName.split(":")[0] in arch:
1930-
reason = f"skipIfRocm: test skipped on {arch}"
1931-
raise unittest.SkipTest(reason)
1945+
if TEST_WITH_ROCM and isRocmArchAnyOf(arch):
1946+
reason = f"skipIfRocm: test skipped on {arch}"
1947+
raise unittest.SkipTest(reason)
19321948
return fn(self, *args, **kwargs)
19331949
return wrap_fn
19341950
return dec_fn
@@ -1946,11 +1962,9 @@ def runOnRocmArch(arch: tuple[str, ...]):
19461962
def dec_fn(fn):
19471963
@wraps(fn)
19481964
def wrap_fn(self, *args, **kwargs):
1949-
if TEST_WITH_ROCM:
1950-
prop = torch.cuda.get_device_properties(0)
1951-
if prop.gcnArchName.split(":")[0] not in arch:
1952-
reason = f"skipIfRocm: test only runs on {arch}"
1953-
raise unittest.SkipTest(reason)
1965+
if TEST_WITH_ROCM and not isRocmArchAnyOf(arch):
1966+
reason = f"skipIfRocm: test only runs on {arch}"
1967+
raise unittest.SkipTest(reason)
19541968
return fn(self, *args, **kwargs)
19551969
return wrap_fn
19561970
return dec_fn
@@ -2010,15 +2024,18 @@ def wrapper(*args, **kwargs):
20102024
fn(*args, **kwargs)
20112025
return wrapper
20122026

2027+
def getRocmVersion() -> tuple[int, int]:
2028+
from torch.testing._internal.common_cuda import _get_torch_rocm_version
2029+
rocm_version = _get_torch_rocm_version()
2030+
return (rocm_version[0], rocm_version[1])
2031+
20132032
# Skips a test on CUDA if ROCm is available and its version is lower than requested.
20142033
def skipIfRocmVersionLessThan(version=None):
20152034
def dec_fn(fn):
20162035
@wraps(fn)
20172036
def wrap_fn(self, *args, **kwargs):
20182037
if TEST_WITH_ROCM:
2019-
rocm_version = str(torch.version.hip)
2020-
rocm_version = rocm_version.split("-", maxsplit=1)[0] # ignore git sha
2021-
rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
2038+
rocm_version_tuple = getRocmVersion()
20222039
if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
20232040
reason = f"ROCm {rocm_version_tuple} is available but {version} required"
20242041
raise unittest.SkipTest(reason)

0 commit comments

Comments
 (0)