|
33 | 33 |
|
34 | 34 | IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)]) |
35 | 35 |
|
36 | | -def CDNA2OrLater(): |
37 | | - if TEST_WITH_ROCM: |
38 | | - gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName |
39 | | - return any(arch in gcn_arch_name for arch in {"gfx90a", "gfx940", "gfx941", "gfx942"}) |
40 | | - return False |
41 | | - |
42 | | -def evaluate_gfx_arch_exact(matching_arch): |
| 36 | +def evaluate_gfx_arch_within(arch_list): |
43 | 37 | if not torch.cuda.is_available(): |
44 | 38 | return False |
45 | 39 | gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName |
46 | | - arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name) |
47 | | - return arch == matching_arch |
| 40 | + effective_arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name) |
| 41 | + # gcnArchName can be complicated strings like gfx90a:sramecc+:xnack- |
| 42 | + # Hence the matching should be done reversely |
| 43 | + return any(arch in effective_arch for arch in arch_list) |
48 | 44 |
|
49 | | -GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-')) |
50 | | -GFX942_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')) |
| 45 | +def CDNA2OrLater(): |
| 46 | + return evaluate_gfx_arch_within(["gfx90a", "gfx942"]) |
51 | 47 |
|
52 | 48 | def evaluate_platform_supports_flash_attention(): |
53 | 49 | if TEST_WITH_ROCM: |
54 | | - return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-') |
| 50 | + arch_list = ["gfx90a", "gfx942", "gfx1100"] |
| 51 | + return evaluate_gfx_arch_within(arch_list) |
55 | 52 | if TEST_CUDA: |
56 | 53 | return not IS_WINDOWS and SM80OrLater |
57 | 54 | return False |
58 | 55 |
|
59 | 56 | def evaluate_platform_supports_efficient_attention(): |
60 | 57 | if TEST_WITH_ROCM: |
61 | | - return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-') |
| 58 | + arch_list = ["gfx90a", "gfx942", "gfx1100"] |
| 59 | + return evaluate_gfx_arch_within(arch_list) |
62 | 60 | if TEST_CUDA: |
63 | 61 | return True |
64 | 62 | return False |
|
0 commit comments