101101except ImportError :
102102 has_pytest = False
103103
104-
104+ MI350_ARCH = ( "gfx950" ,)
105105MI300_ARCH = ("gfx942" ,)
106-
106+ MI200_ARCH = ("gfx90a" )
107+ NAVI_ARCH = ("gfx1030" , "gfx1100" , "gfx1101" , "gfx1200" , "gfx1201" )
108+ NAVI3_ARCH = ("gfx1100" , "gfx1101" )
109+ NAVI4_ARCH = ("gfx1200" , "gfx1201" )
110+
111+ def is_navi3_arch ():
112+ if torch .cuda .is_available ():
113+ prop = torch .cuda .get_device_properties (0 )
114+ gfx_arch = prop .gcnArchName .split (":" )[0 ]
115+ if gfx_arch in NAVI3_ARCH :
116+ return True
117+ return False
107118
108119def freeze_rng_state (* args , ** kwargs ):
109120 return torch .testing ._utils .freeze_rng_state (* args , ** kwargs )
@@ -1920,15 +1931,20 @@ def wrapper(*args, **kwargs):
19201931 return dec_fn (func )
19211932 return dec_fn
19221933
1934+ def getRocmArchName (device_index : int = 0 ):
1935+ return torch .cuda .get_device_properties (device_index ).gcnArchName
1936+
1937+ def isRocmArchAnyOf (arch : tuple [str , ...]):
1938+ rocmArch = getRocmArchName ()
1939+ return any (x in rocmArch for x in arch )
1940+
19231941def skipIfRocmArch (arch : tuple [str , ...]):
19241942 def dec_fn (fn ):
19251943 @wraps (fn )
19261944 def wrap_fn (self , * args , ** kwargs ):
1927- if TEST_WITH_ROCM :
1928- prop = torch .cuda .get_device_properties (0 )
1929- if prop .gcnArchName .split (":" )[0 ] in arch :
1930- reason = f"skipIfRocm: test skipped on { arch } "
1931- raise unittest .SkipTest (reason )
1945+ if TEST_WITH_ROCM and isRocmArchAnyOf (arch ):
1946+ reason = f"skipIfRocm: test skipped on { arch } "
1947+ raise unittest .SkipTest (reason )
19321948 return fn (self , * args , ** kwargs )
19331949 return wrap_fn
19341950 return dec_fn
@@ -1946,11 +1962,9 @@ def runOnRocmArch(arch: tuple[str, ...]):
19461962 def dec_fn (fn ):
19471963 @wraps (fn )
19481964 def wrap_fn (self , * args , ** kwargs ):
1949- if TEST_WITH_ROCM :
1950- prop = torch .cuda .get_device_properties (0 )
1951- if prop .gcnArchName .split (":" )[0 ] not in arch :
1952- reason = f"skipIfRocm: test only runs on { arch } "
1953- raise unittest .SkipTest (reason )
1965+ if TEST_WITH_ROCM and not isRocmArchAnyOf (arch ):
1966+ reason = f"skipIfRocm: test only runs on { arch } "
1967+ raise unittest .SkipTest (reason )
19541968 return fn (self , * args , ** kwargs )
19551969 return wrap_fn
19561970 return dec_fn
@@ -2010,15 +2024,18 @@ def wrapper(*args, **kwargs):
20102024 fn (* args , ** kwargs )
20112025 return wrapper
20122026
2027+ def getRocmVersion () -> tuple [int , int ]:
2028+ from torch .testing ._internal .common_cuda import _get_torch_rocm_version
2029+ rocm_version = _get_torch_rocm_version ()
2030+ return (rocm_version [0 ], rocm_version [1 ])
2031+
20132032# Skips a test on CUDA if ROCm is available and its version is lower than requested.
20142033def skipIfRocmVersionLessThan (version = None ):
20152034 def dec_fn (fn ):
20162035 @wraps (fn )
20172036 def wrap_fn (self , * args , ** kwargs ):
20182037 if TEST_WITH_ROCM :
2019- rocm_version = str (torch .version .hip )
2020- rocm_version = rocm_version .split ("-" , maxsplit = 1 )[0 ] # ignore git sha
2021- rocm_version_tuple = tuple (int (x ) for x in rocm_version .split ("." ))
2038+ rocm_version_tuple = getRocmVersion ()
20222039 if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple (version ):
20232040 reason = f"ROCm { rocm_version_tuple } is available but { version } required"
20242041 raise unittest .SkipTest (reason )
0 commit comments