
Commit eb3431b

AmdSampsa authored and pragupta committed
[release/2.7] Fix SDPA skip logic (#2281)
fixes https://ontrack-internal.amd.com/browse/SWDEV-522391 for PT 2.7 (cherry picked from commit df38cca)
1 parent fb81400 · commit eb3431b
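
For context, here is a minimal sketch (not part of this commit) of the skip pattern the diff below settles on: each test is gated on the capability flag for the SDPA backend it actually exercises, rather than unconditionally on the flash-attention flag. The PLATFORM_SUPPORTS_* flags come from torch.testing._internal.common_cuda exactly as in the diff; the toy test class, tensor shapes, and device string are illustrative assumptions.

import unittest

import torch
import torch.nn.functional as F
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)


class SDPASkipExample(unittest.TestCase):
    # Flash-attention tests keep the flash flag; per the comment added in the
    # diff, the math fallback is not a usable substitute for bfloat16 here.
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support flash SDPA"
    )
    def test_flash_sdpa(self):
        q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
        F.scaled_dot_product_attention(q, k, v)

    # The fix in this commit: efficient-attention tests now skip on the
    # memory-efficient flag instead of the flash flag.
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
    )
    def test_mem_eff_sdpa(self):
        q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
        F.scaled_dot_product_attention(q, k, v)

Both flags are computed in common_cuda from the current device and build, so the skips track what the local CUDA or ROCm platform actually supports.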

File tree

2 files changed: +18 -5 lines changed


test/inductor/test_aot_inductor.py

Lines changed: 15 additions & 5 deletions
@@ -32,9 +32,10 @@
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    PLATFORM_SUPPORTS_FP8,
+    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     SM80OrLater,
-    SM90OrLater,
-    PLATFORM_SUPPORTS_FLASH_ATTENTION
 )
 from torch.testing._internal.common_device_type import (
     _has_sufficient_memory,
@@ -1367,7 +1368,12 @@ def forward(self, q, k, v):
         self.check_model(Model(), example_inputs)

     @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
+    @unittest.skipIf(
+        # for archs where this isn't lowered to flash attention, the math
+        # backend will be used and it doesn't work for bfloat16
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Some archs don't support SDPA with bfloat16",
+    )
     def test_sdpa_2(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
@@ -1620,7 +1626,9 @@ def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
         self.check_model(Repro(), example_inputs, dynamic_shapes=spec)

     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support flash SDPA"
+    )
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")
@@ -4179,7 +4187,9 @@ def grid(meta):
             dynamic_shapes=dynamic_shapes,
         )

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
+    )
     def test_scaled_dot_product_efficient_attention(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")

test/inductor/test_torchinductor.py

Lines changed: 3 additions & 0 deletions
@@ -11538,6 +11538,9 @@ def fn(q, k, v):

     @xfail_if_mps_unimplemented
     @expectedFailureXPU
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
+    )
     def test_scaled_dot_product_efficient_attention(self):
         if self.device == "cpu":
             raise unittest.SkipTest(f"requires {GPU_TYPE}")
