 from torch.export import Dim, export, export_for_training
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM80OrLater
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    PLATFORM_SUPPORTS_FP8,
+    SM80OrLater,
+    SM90OrLater,
+)
 from torch.testing._internal.common_device_type import (
     _has_sufficient_memory,
     skipCUDAIf,
@@ -1008,6 +1012,7 @@ def forward(self, q, k, v):
 
     @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
     @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_sdpa_2(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
@@ -1114,9 +1119,8 @@ def forward(self, x, y):
         )
         self.check_model(Repro(), example_inputs)
 
-    @skipIfRocmArch(NAVI32_ARCH)
-    # SDPA is not supported on navi32 arch
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")
@@ -3328,8 +3332,7 @@ def grid(meta):
             dynamic_shapes=dynamic_shapes,
         )
 
-    @skipIfRocmArch(NAVI32_ARCH)
-    # SDPA is not supported on navi32 arch
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_scaled_dot_product_efficient_attention(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")
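
For context, every hunk above applies the same pattern: gate SDPA-dependent tests on PLATFORM_SUPPORTS_FLASH_ATTENTION rather than skipping specific ROCm architectures. Below is a minimal, self-contained sketch of that pattern, assuming a PyTorch build where torch.testing._internal.common_cuda exposes this flag; the test class, test name, and tensor shapes are illustrative stand-ins, not code from this PR.

# Hypothetical standalone example (not part of this diff): skip SDPA coverage
# on platforms without flash attention, mirroring the decorators added above.
import unittest

import torch
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION


class SDPASkipExample(unittest.TestCase):
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
    )
    def test_sdpa_smoke(self):
        # (batch, heads, seq_len, head_dim) chosen arbitrarily for illustration.
        q = k = v = torch.randn(1, 4, 8, 16, device="cuda", dtype=torch.float16)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        self.assertEqual(out.shape, q.shape)


if __name__ == "__main__":
    unittest.main()

On machines without flash-attention support (or without CUDA at all) the flag evaluates to False, so the test should be reported as skipped instead of failing, which is the behavior this PR relies on for the affected architectures.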