Skip to content

Commit f0c1ce8

Browse files
[AUTOGENERATED] [release/2.7] [rocm6.4_internal_testing] Replaced ROCm specific skips to generalized conditions (#2126)
Cherry-pick of #2100 Need to resolve conflicts --------- Co-authored-by: iupaikov-amd <[email protected]>
1 parent 77a7b6c commit f0c1ce8

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@
2727
from torch.export import Dim, export, export_for_training
2828
from torch.testing import FileCheck
2929
from torch.testing._internal import common_utils
30-
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM80OrLater
30+
from torch.testing._internal.common_cuda import (
31+
SM80OrLater,
32+
SM90OrLater,
33+
PLATFORM_SUPPORTS_FLASH_ATTENTION
34+
)
3135
from torch.testing._internal.common_device_type import (
3236
_has_sufficient_memory,
3337
skipCUDAIf,
@@ -1008,6 +1012,7 @@ def forward(self, q, k, v):
10081012

10091013
@unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
10101014
@unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
1015+
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
10111016
def test_sdpa_2(self):
10121017
class Model(torch.nn.Module):
10131018
def __init__(self) -> None:
@@ -1114,9 +1119,8 @@ def forward(self, x, y):
11141119
)
11151120
self.check_model(Repro(), example_inputs)
11161121

1117-
@skipIfRocmArch(NAVI32_ARCH)
1118-
# SDPA is not supported on navi32 arch
11191122
@skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
1123+
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
11201124
def test_fallback_kernel_with_symexpr_output(self):
11211125
if self.device != GPU_TYPE:
11221126
raise unittest.SkipTest("requires GPU")
@@ -3328,8 +3332,7 @@ def grid(meta):
33283332
dynamic_shapes=dynamic_shapes,
33293333
)
33303334

3331-
@skipIfRocmArch(NAVI32_ARCH)
3332-
# SDPA is not supported on navi32 arch
3335+
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
33333336
def test_scaled_dot_product_efficient_attention(self):
33343337
if self.device != GPU_TYPE:
33353338
raise unittest.SkipTest("requires GPU")

0 commit comments

Comments
 (0)