@@ -30,6 +30,7 @@
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     SM80OrLater,
     SM90OrLater,
 )
@@ -929,7 +930,10 @@ def forward(self, q, k, v):
     @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
     @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
+        # for archs where this isn't lowered to flash attention, the math
+        # backend will be used and it doesn't work for bfloat16
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Some archs don't support SDPA with bfloat16",
     )
     def test_sdpa_2(self):
         class Model(torch.nn.Module):
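The new comment in the hunk above records why the skip message was tightened: on architectures without flash attention, scaled dot-product attention falls back to the math backend, and this test's bfloat16 path relies on the flash kernel. As a rough illustration only (not part of the PR; the shapes and dtype are made up, and it assumes a PyTorch version that ships `torch.nn.attention.sdpa_kernel`), the two backends can be pinned explicitly to see which kernel a given arch actually gets:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Arbitrary bfloat16 inputs in (batch, heads, seq_len, head_dim) layout.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# On SM80+ with flash attention available, this runs the flash kernel.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out_flash = F.scaled_dot_product_attention(q, k, v)

# Forcing the math backend mimics what archs without flash attention fall
# back to, which is the case the updated skip reason is guarding against.
with sdpa_kernel([SDPBackend.MATH]):
    out_math = F.scaled_dot_product_attention(q, k, v)
```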
@@ -1039,7 +1043,7 @@ def forward(self, x, y):
 
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support flash SDPA"
     )
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE:
@@ -3036,7 +3040,7 @@ def grid(meta):
         )
 
     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
     )
     def test_scaled_dot_product_efficient_attention(self):
         if self.device != GPU_TYPE:
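For the efficient-attention test, the guard now keys off the newly imported `PLATFORM_SUPPORTS_MEM_EFF_ATTENTION` flag rather than the flash-attention one, so the skip condition matches the kernel the test actually exercises. Below is a minimal standalone sketch of that skip idiom; only the flag and the `unittest.skipIf` pattern come from the diff, while the class name, test name, shapes, and assertion are invented for illustration:

```python
import unittest

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_MEM_EFF_ATTENTION


class SDPAGuardExample(unittest.TestCase):  # hypothetical test class
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Some archs don't support mem eff SDPA",
    )
    def test_mem_eff_sdpa(self):
        # Arbitrary fp32 inputs; shapes chosen only for illustration.
        q = torch.randn(2, 8, 128, 64, device="cuda")
        k = torch.randn(2, 8, 128, 64, device="cuda")
        v = torch.randn(2, 8, 128, 64, device="cuda")
        # Pin the memory-efficient backend so the test exercises exactly the
        # kernel the capability flag is checking for.
        with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
            out = F.scaled_dot_product_attention(q, k, v)
        self.assertEqual(out.shape, q.shape)


if __name__ == "__main__":
    unittest.main()
```

Pinning `SDPBackend.EFFICIENT_ATTENTION` keeps the guarded test from silently falling back to another backend on unsupported hardware, which is the same mismatch the diff fixes by swapping the flag.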