Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion functorch/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# PyTorch forward-mode is not mature yet
from functorch import functionalize
from torch._functorch.apis import chunk_vmap
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from torch._functorch.eager_transforms import hessian, jacfwd, jvp
from torch.func import functionalize
47 changes: 32 additions & 15 deletions torch/testing/_internal/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,20 @@
except ImportError:
has_pytest = False


# GPU architecture (gfx) name tuples used by the ROCm-specific skip decorators.
MI350_ARCH = ("gfx950",)
MI300_ARCH = ("gfx942",)
# The trailing comma is required: ("gfx90a") without it is just a parenthesized
# string, and membership checks would then iterate it one character at a time.
MI200_ARCH = ("gfx90a",)
NAVI_ARCH = ("gfx1030", "gfx1100", "gfx1101", "gfx1200", "gfx1201")
NAVI3_ARCH = ("gfx1100", "gfx1101")
NAVI4_ARCH = ("gfx1200", "gfx1201")

def is_navi3_arch():
    """Return True when the active CUDA/ROCm device reports a Navi3 gfx arch."""
    if not torch.cuda.is_available():
        return False
    # gcnArchName may carry feature suffixes (e.g. "gfx1100:xnack-"); compare
    # only the base arch token before the first ":".
    device_props = torch.cuda.get_device_properties(0)
    base_arch = device_props.gcnArchName.split(":")[0]
    return base_arch in NAVI3_ARCH

def freeze_rng_state(*args, **kwargs):
    """Backward-compatibility wrapper: delegates to torch.testing._utils.freeze_rng_state."""
    return torch.testing._utils.freeze_rng_state(*args, **kwargs)
Expand Down Expand Up @@ -1920,15 +1931,20 @@ def wrapper(*args, **kwargs):
return dec_fn(func)
return dec_fn

def getRocmArchName(device_index: int = 0):
    """Return the raw gcnArchName string reported for the given GPU device.

    Defaults to device 0. The string may include feature suffixes
    (e.g. ``gfx90a:sramecc+:xnack-``).
    """
    device_props = torch.cuda.get_device_properties(device_index)
    return device_props.gcnArchName

def isRocmArchAnyOf(arch: tuple[str, ...]):
    """Return True if the current device's ROCm arch matches any entry in *arch*.

    Each entry is matched as a substring of the full gcnArchName (which may
    carry feature suffixes such as ``gfx90a:sramecc+:xnack-``), so bare
    ``gfx``-prefixed names match regardless of feature flags.
    """
    # Guard against a bare string being passed: iterating a string would test
    # one character at a time (e.g. "g" in "gfx942") and match almost anything.
    if isinstance(arch, str):
        arch = (arch,)
    rocm_arch_name = getRocmArchName()
    return any(candidate in rocm_arch_name for candidate in arch)

def skipIfRocmArch(arch: tuple[str, ...]):
    """Decorator: under ROCm, skip the test when the device arch matches *arch*.

    No-op on non-ROCm builds. *arch* is a tuple of gfx arch names, e.g.
    ``("gfx942",)``; matching is delegated to ``isRocmArchAnyOf``.
    """
    def dec_fn(fn):
        @wraps(fn)
        def wrap_fn(self, *args, **kwargs):
            if TEST_WITH_ROCM and isRocmArchAnyOf(arch):
                reason = f"skipIfRocm: test skipped on {arch}"
                raise unittest.SkipTest(reason)
            return fn(self, *args, **kwargs)
        return wrap_fn
    return dec_fn
def runOnRocmArch(arch: tuple[str, ...]):
    """Decorator: under ROCm, skip the test unless the device arch matches *arch*.

    No-op on non-ROCm builds (the test still runs on CUDA/CPU). *arch* is a
    tuple of gfx arch names; matching is delegated to ``isRocmArchAnyOf``.
    """
    def dec_fn(fn):
        @wraps(fn)
        def wrap_fn(self, *args, **kwargs):
            if TEST_WITH_ROCM and not isRocmArchAnyOf(arch):
                reason = f"skipIfRocm: test only runs on {arch}"
                raise unittest.SkipTest(reason)
            return fn(self, *args, **kwargs)
        return wrap_fn
    return dec_fn
Expand Down Expand Up @@ -2010,15 +2024,18 @@ def wrapper(*args, **kwargs):
fn(*args, **kwargs)
return wrapper

def getRocmVersion() -> tuple[int, int]:
    """Return the installed ROCm version as a ``(major, minor)`` tuple."""
    # NOTE(review): imported locally, presumably to avoid an import cycle with
    # common_cuda — confirm before hoisting to module level.
    from torch.testing._internal.common_cuda import _get_torch_rocm_version
    version = _get_torch_rocm_version()
    major, minor = version[0], version[1]
    return (major, minor)

# Skips a test on CUDA if ROCm is available and its version is lower than requested.
def skipIfRocmVersionLessThan(version=None):
def dec_fn(fn):
@wraps(fn)
def wrap_fn(self, *args, **kwargs):
if TEST_WITH_ROCM:
rocm_version = str(torch.version.hip)
rocm_version = rocm_version.split("-", maxsplit=1)[0] # ignore git sha
rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
rocm_version_tuple = getRocmVersion()
if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
reason = f"ROCm {rocm_version_tuple} is available but {version} required"
raise unittest.SkipTest(reason)
Expand Down