 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    PLATFORM_SUPPORTS_FP8,
+    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     SM80OrLater,
-    SM90OrLater,
-    PLATFORM_SUPPORTS_FLASH_ATTENTION
 )
 from torch.testing._internal.common_device_type import (
     _has_sufficient_memory,
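The newly imported PLATFORM_SUPPORTS_FP8 flag is not exercised in the hunks shown here. A minimal sketch of how such a capability flag typically gates a test (the class and test names are hypothetical, not from this PR; assumes a CUDA build):

import unittest

import torch
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8


class ExampleGpuTests(unittest.TestCase):  # hypothetical, not part of this PR
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 not supported on this platform")
    def test_fp8_cast_roundtrip(self):
        x = torch.randn(8, 8, device="cuda", dtype=torch.float32)
        # Round-trip through an FP8 dtype; only exercises basic FP8 support.
        y = x.to(torch.float8_e4m3fn).to(torch.float32)
        self.assertEqual(y.shape, x.shape)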
@@ -1367,7 +1368,12 @@ def forward(self, q, k, v):
         self.check_model(Model(), example_inputs)

     @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
+    @unittest.skipIf(
+        # for archs where this isn't lowered to flash attention, the math
+        # backend will be used and it doesn't work for bfloat16
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Some archs don't support SDPA with bfloat16",
+    )
     def test_sdpa_2(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
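The comment added in this hunk refers to SDPA's silent backend fallback. As a minimal sketch (not part of this PR; assumes PyTorch 2.3+ where torch.nn.attention.sdpa_kernel is available, with illustrative shapes), dispatch can be pinned to the flash backend so that unsupported archs raise instead of quietly falling back to the math backend:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Allow only the flash backend; if the arch cannot use flash attention this
# raises a RuntimeError instead of silently dispatching to the math backend.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)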
@@ -1620,7 +1626,9 @@ def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
         self.check_model(Repro(), example_inputs, dynamic_shapes=spec)

     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support flash SDPA"
+    )
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")
@@ -4179,7 +4187,9 @@ def grid(meta):
             dynamic_shapes=dynamic_shapes,
         )

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
+    )
     def test_scaled_dot_product_efficient_attention(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")
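This hunk switches the guard because the efficient-attention test exercises the memory-efficient backend, which has its own capability flag distinct from flash attention. A companion sketch (again hypothetical, not from this PR; PyTorch 2.3+, illustrative shapes) pinning dispatch to that backend:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 4, 64, 32, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Allow only the memory-efficient backend; archs without it raise here
# rather than falling back to another implementation.
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)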