@@ -31,7 +31,11 @@
 from torch.export.pt2_archive._package import load_pt2
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM80OrLater
+from torch.testing._internal.common_cuda import (
+    SM80OrLater,
+    SM90OrLater,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
+)
 from torch.testing._internal.common_device_type import (
     _has_sufficient_memory,
     skipCUDAIf,
@@ -1363,6 +1367,7 @@ def forward(self, q, k, v):
         self.check_model(Model(), example_inputs)
 
     @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_sdpa_2(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
@@ -1615,6 +1620,7 @@ def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
         self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
 
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")
@@ -4173,6 +4179,7 @@ def grid(meta):
             dynamic_shapes=dynamic_shapes,
         )
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_scaled_dot_product_efficient_attention(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")
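
For reference, the skip pattern added above can be exercised in isolation. The following is a minimal standalone sketch, not part of this diff: the test class and test name are hypothetical, while unittest.skipIf, PLATFORM_SUPPORTS_FLASH_ATTENTION, and torch.nn.functional.scaled_dot_product_attention are existing APIs used the same way as in the patched file.

import unittest

import torch
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION


class SDPASkipExample(unittest.TestCase):
    # Same guard as in the diff: the flag evaluates to False on architectures
    # (and on machines without a suitable CUDA device) where the flash-attention
    # SDPA backend is unavailable, so the test is skipped rather than failing.
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
    )
    def test_sdpa_runs(self):
        # (batch, heads, seq_len, head_dim) in fp16 on CUDA, a shape/dtype
        # combination eligible for the flash-attention backend.
        q, k, v = (
            torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
            for _ in range(3)
        )
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        self.assertEqual(out.shape, q.shape)


if __name__ == "__main__":
    unittest.main()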