From 011ed3e81c6c95d531f5a33d4a339848084c9184 Mon Sep 17 00:00:00 2001
From: Artem Kuzmitckii
Date: Tue, 4 Nov 2025 11:53:51 +0000
Subject: [PATCH] [release/2.9] add AMD routines to common_utils.py of the
 test framework

Signed-off-by: Artem Kuzmitckii
---
 functorch/experimental/__init__.py      |  2 +-
 torch/testing/_internal/common_utils.py | 47 +++++++++++++++++--------
 2 files changed, 33 insertions(+), 16 deletions(-)

diff --git a/functorch/experimental/__init__.py b/functorch/experimental/__init__.py
index 3941f6d96e1f6..0500fc2c29d35 100644
--- a/functorch/experimental/__init__.py
+++ b/functorch/experimental/__init__.py
@@ -1,5 +1,5 @@
 # PyTorch forward-mode is not mature yet
-from functorch import functionalize
 from torch._functorch.apis import chunk_vmap
 from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
 from torch._functorch.eager_transforms import hessian, jacfwd, jvp
+from torch.func import functionalize
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index bfc568bc14645..38dc910f595e4 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -101,9 +101,20 @@
 except ImportError:
     has_pytest = False
 
-
+MI350_ARCH = ("gfx950",)
 MI300_ARCH = ("gfx942",)
-
+MI200_ARCH = ("gfx90a",)
+NAVI_ARCH = ("gfx1030", "gfx1100", "gfx1101", "gfx1200", "gfx1201")
+NAVI3_ARCH = ("gfx1100", "gfx1101")
+NAVI4_ARCH = ("gfx1200", "gfx1201")
+
+def is_navi3_arch():
+    if torch.cuda.is_available():
+        prop = torch.cuda.get_device_properties(0)
+        gfx_arch = prop.gcnArchName.split(":")[0]
+        if gfx_arch in NAVI3_ARCH:
+            return True
+    return False
 
 def freeze_rng_state(*args, **kwargs):
     return torch.testing._utils.freeze_rng_state(*args, **kwargs)
@@ -1920,15 +1931,20 @@ def wrapper(*args, **kwargs):
         return dec_fn(func)
     return dec_fn
 
+def getRocmArchName(device_index: int = 0) -> str:
+    return torch.cuda.get_device_properties(device_index).gcnArchName
+
+def isRocmArchAnyOf(arch: tuple[str, ...]) -> bool:
+    rocmArch = getRocmArchName()
+    return any(x in rocmArch for x in arch)
+
 def skipIfRocmArch(arch: tuple[str, ...]):
     def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
-            if TEST_WITH_ROCM:
-                prop = torch.cuda.get_device_properties(0)
-                if prop.gcnArchName.split(":")[0] in arch:
-                    reason = f"skipIfRocm: test skipped on {arch}"
-                    raise unittest.SkipTest(reason)
+            if TEST_WITH_ROCM and isRocmArchAnyOf(arch):
+                reason = f"skipIfRocm: test skipped on {arch}"
+                raise unittest.SkipTest(reason)
             return fn(self, *args, **kwargs)
         return wrap_fn
     return dec_fn
@@ -1946,11 +1962,9 @@ def runOnRocmArch(arch: tuple[str, ...]):
     def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
-            if TEST_WITH_ROCM:
-                prop = torch.cuda.get_device_properties(0)
-                if prop.gcnArchName.split(":")[0] not in arch:
-                    reason = f"skipIfRocm: test only runs on {arch}"
-                    raise unittest.SkipTest(reason)
+            if TEST_WITH_ROCM and not isRocmArchAnyOf(arch):
+                reason = f"skipIfRocm: test only runs on {arch}"
+                raise unittest.SkipTest(reason)
             return fn(self, *args, **kwargs)
         return wrap_fn
     return dec_fn
@@ -2010,15 +2024,18 @@ def wrapper(*args, **kwargs):
             fn(*args, **kwargs)
     return wrapper
 
+def getRocmVersion() -> tuple[int, int]:
+    from torch.testing._internal.common_cuda import _get_torch_rocm_version
+    rocm_version = _get_torch_rocm_version()
+    return (rocm_version[0], rocm_version[1])
+
 # Skips a test on CUDA if ROCm is available and its version is lower than requested.
 def skipIfRocmVersionLessThan(version=None):
     def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
             if TEST_WITH_ROCM:
-                rocm_version = str(torch.version.hip)
-                rocm_version = rocm_version.split("-", maxsplit=1)[0]  # ignore git sha
-                rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
+                rocm_version_tuple = getRocmVersion()
                 if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
                     reason = f"ROCm {rocm_version_tuple} is available but {version} required"
                     raise unittest.SkipTest(reason)
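
For context, a minimal usage sketch of the decorators this patch touches, assuming
it is applied. The test class and test bodies below are hypothetical illustrations,
not part of the patch:

    # Hypothetical usage sketch (not part of this patch): how the arch- and
    # version-gating decorators added/refactored here are typically applied.
    import torch

    from torch.testing._internal.common_utils import (
        MI200_ARCH,
        NAVI3_ARCH,
        TestCase,
        run_tests,
        runOnRocmArch,
        skipIfRocmArch,
        skipIfRocmVersionLessThan,
    )


    class TestRocmArchGating(TestCase):
        @skipIfRocmArch(NAVI3_ARCH)  # skipped on gfx1100/gfx1101 under ROCm
        def test_skipped_on_navi3(self):
            self.assertTrue(torch.cuda.is_available())

        @runOnRocmArch(MI200_ARCH)  # under ROCm, runs only on gfx90a
        def test_mi200_only(self):
            self.assertTrue(torch.cuda.is_available())

        @skipIfRocmVersionLessThan((6, 2))  # hypothetical minimum version
        def test_needs_recent_rocm(self):
            self.assertTrue(torch.cuda.is_available())


    if __name__ == "__main__":
        run_tests()

On non-ROCm builds all three decorators are pass-throughs (TEST_WITH_ROCM is
false), so the gated tests still run there.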