
Commit 48630d8

okakarpa authored and iupaikov-amd committed
[AUTOGENERATED] [release/2.7] [rocm6.4_internal_testing] Replaced ROCm specific skips to generalized conditions (#2126)
Cherry-pick of #2100. Need to resolve conflicts.

Co-authored-by: iupaikov-amd <[email protected]>
(cherry picked from commit f0c1ce8)
1 parent 23f0b5f · commit 48630d8

File tree

1 file changed: +8 −1 lines changed


test/inductor/test_aot_inductor.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -31,7 +31,11 @@
 from torch.export.pt2_archive._package import load_pt2
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM80OrLater
+from torch.testing._internal.common_cuda import (
+    SM80OrLater,
+    SM90OrLater,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION
+)
 from torch.testing._internal.common_device_type import (
     _has_sufficient_memory,
     skipCUDAIf,
@@ -1363,6 +1367,7 @@ def forward(self, q, k, v):
         self.check_model(Model(), example_inputs)
 
     @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_sdpa_2(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
@@ -1615,6 +1620,7 @@ def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
         self.check_model(Repro(), example_inputs, dynamic_shapes=spec)
 
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")
@@ -4173,6 +4179,7 @@ def grid(meta):
             dynamic_shapes=dynamic_shapes,
         )
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
    def test_scaled_dot_product_efficient_attention(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")
```
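The idea behind the change is to gate tests on a capability flag rather than on a vendor-specific skip (e.g. `skipIfRocm`). The `PLATFORM_SUPPORTS_FLASH_ATTENTION` flag used above already exists in `torch.testing._internal.common_cuda`; the sketch below only illustrates the pattern, and the `_flash_attention_supported` helper (including its gfx90a-only ROCm check) is a hypothetical stand-in for the real, more thorough probe.

```python
import unittest

import torch


# Hypothetical stand-in for the real capability flag,
# torch.testing._internal.common_cuda.PLATFORM_SUPPORTS_FLASH_ATTENTION.
def _flash_attention_supported() -> bool:
    if not torch.cuda.is_available():
        return False
    if torch.version.hip is not None:
        # ROCm: support depends on the gfx architecture, not a CUDA SM level.
        # (Checking only gfx90a here for brevity; the real flag covers more archs.)
        return "gfx90a" in torch.cuda.get_device_properties(0).gcnArchName
    # CUDA: flash attention requires SM80 (Ampere) or newer.
    major, _minor = torch.cuda.get_device_capability(0)
    return major >= 8


class SDPATest(unittest.TestCase):
    # One capability-based skip replaces separate ROCm- and SM-specific skips.
    @unittest.skipIf(
        not _flash_attention_supported(), "Some archs don't support SDPA"
    )
    def test_sdpa(self):
        q = k = v = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        self.assertEqual(out.shape, q.shape)


if __name__ == "__main__":
    unittest.main()
```

In the actual diff the tests import the flag directly, so one decorator covers ROCm, CUDA, and any other backend that reports the capability, which is what makes the skip "generalized" rather than ROCm-specific.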
