Commit 0154656

[release/2.6] Update SDPA skip logic for Navi (#2280)
Fixes this one: https://ontrack-internal.amd.com/browse/SWDEV-522391
1 parent d1c90a0 commit 0154656
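
For context, the commit keys the memory-efficient-attention tests on the mem-efficient capability flag instead of the flash-attention one, and gates the bfloat16 SDPA test with a clearer message. Below is a minimal sketch of the skip pattern the diff relies on; the flag names, module path, and skip message come from the diff, while the test class and body are hypothetical and the helpers are internal to PyTorch's test suite rather than public API:

# Sketch only: PLATFORM_SUPPORTS_* are internal PyTorch test helpers (see the
# import hunk below); DummySDPATest and its body are made-up for illustration.
import unittest

import torch
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)


class DummySDPATest(unittest.TestCase):
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
    )
    def test_efficient_attention_runs(self):
        # (batch, heads, seq, head_dim); requires a CUDA/ROCm device
        q = k = v = torch.randn(1, 4, 8, 16, device="cuda", dtype=torch.float16)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        self.assertEqual(out.shape, q.shape)
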

File tree

2 files changed: +10 -3 lines changed

test/inductor/test_aot_inductor.py

Lines changed: 7 additions & 3 deletions

@@ -30,6 +30,7 @@
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     SM80OrLater,
     SM90OrLater,
 )
@@ -929,7 +930,10 @@ def forward(self, q, k, v):
     @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
     @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
+        # for archs where this isn't lowered to flash attention, the math
+        # backend will be used and it doesn't work for bfloat16
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Some archs don't support SDPA with bfloat16",
     )
     def test_sdpa_2(self):
         class Model(torch.nn.Module):
@@ -1039,7 +1043,7 @@ def forward(self, x, y):

     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support flash SDPA"
     )
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE:
@@ -3036,7 +3040,7 @@ def grid(meta):
         )

     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
     )
     def test_scaled_dot_product_efficient_attention(self):
         if self.device != GPU_TYPE:

test/inductor/test_torchinductor.py

Lines changed: 3 additions & 0 deletions

@@ -10611,6 +10611,9 @@ def fn(q, k, v):
         )

     @expectedFailureXPU
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
+    )
     def test_scaled_dot_product_efficient_attention(self):
         if self.device == "cpu":
             raise unittest.SkipTest(f"requires {GPU_TYPE}")
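
A quick way to see which backends the updated skip decorators will treat as supported on a given machine is to print the two flags directly. This is a minimal sketch under the assumption that the internal test helpers are importable from the installed PyTorch build; the actual values depend on GPU architecture and how the build was configured:

# Minimal sketch: prints the internal capability flags consulted by the
# skipIf decorators in the diff; not a public API.
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)

print("flash SDPA supported:  ", PLATFORM_SUPPORTS_FLASH_ATTENTION)
print("mem-eff SDPA supported:", PLATFORM_SUPPORTS_MEM_EFF_ATTENTION)
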
