Skip to content

Commit 5a980aa

Browse files
authored
[ROCm] testing: enable MEFF/FA unittests for gfx1100 (#1951)
Include gfx1100, and optionally enable gfx1201/gfx950 according to the env var TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL. Fixes #SWDEV-515529
1 parent cb9541a commit 5a980aa

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

torch/testing/_internal/common_cuda.py

Lines changed: 11 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -33,32 +33,30 @@
3333

3434
IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)])
3535

36-
def CDNA2OrLater():
37-
if TEST_WITH_ROCM:
38-
gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
39-
return any(arch in gcn_arch_name for arch in {"gfx90a", "gfx940", "gfx941", "gfx942"})
40-
return False
41-
42-
def evaluate_gfx_arch_exact(matching_arch):
36+
def evaluate_gfx_arch_within(arch_list):
4337
if not torch.cuda.is_available():
4438
return False
4539
gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
46-
arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
47-
return arch == matching_arch
40+
effective_arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
41+
# gcnArchName can be complicated strings like gfx90a:sramecc+:xnack-
42+
# Hence the match is done in reverse: check whether each listed arch is a substring of the device's arch name
43+
return any(arch in effective_arch for arch in arch_list)
4844

49-
GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-'))
50-
GFX942_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-'))
45+
def CDNA2OrLater():
46+
return evaluate_gfx_arch_within(["gfx90a", "gfx942"])
5147

5248
def evaluate_platform_supports_flash_attention():
5349
if TEST_WITH_ROCM:
54-
return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
50+
arch_list = ["gfx90a", "gfx942", "gfx1100"]
51+
return evaluate_gfx_arch_within(arch_list)
5552
if TEST_CUDA:
5653
return not IS_WINDOWS and SM80OrLater
5754
return False
5855

5956
def evaluate_platform_supports_efficient_attention():
6057
if TEST_WITH_ROCM:
61-
return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
58+
arch_list = ["gfx90a", "gfx942", "gfx1100"]
59+
return evaluate_gfx_arch_within(arch_list)
6260
if TEST_CUDA:
6361
return True
6462
return False

0 commit comments

Comments
 (0)