
Commit 07f412a

[AUTOGENERATED] [release/2.7] NAVI32 specific fixes (#2466)

Cherry-pick of #2450

Co-authored-by: iupaikov-amd <[email protected]>

1 parent 698b58a · commit 07f412a

2 files changed (+16, −2)

test/inductor/test_flex_decoding.py

Lines changed: 5 additions & 1 deletion
@@ -21,7 +21,10 @@
 )
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_BF16,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
+)
 from torch.testing._internal.common_utils import skipIfRocm
 from torch.utils._triton import has_triton

@@ -1421,6 +1424,7 @@ def mask_mod(b, h, q, kv):
         self.assertEqual(query.grad[:, :, M:, :].sum(), 0)

     @supported_platform
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_windowed_no_mask_vs_sdpa(self):
         score_mod = _generate_windowed(1000)
         attention = functools.partial(flex_attention, score_mod=score_mod)
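For context, the decorator added above is the standard unittest gating pattern: when PLATFORM_SUPPORTS_FLASH_ATTENTION is False (as on archs such as NAVI32 that lack flash-attention SDPA), the test is skipped rather than failed. A minimal self-contained sketch of the same pattern; the test class and method names here are hypothetical:

import unittest

from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION


class SketchFlexDecodingTest(unittest.TestCase):  # hypothetical test class
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
    )
    def test_against_sdpa_reference(self):
        # A real test would compare flex_attention output against an SDPA
        # reference here; on unsupported archs this body never runs.
        pass


if __name__ == "__main__":
    unittest.main()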

test/inductor/test_max_autotune.py

Lines changed: 11 additions & 1 deletion
@@ -27,6 +27,7 @@
     TritonTemplateCaller,
 )
 from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
+from torch.testing._internal.common_device_type import largeTensorTest
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     IS_WINDOWS,
@@ -44,7 +45,12 @@
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing import FileCheck
 from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu
-from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU
+from torch.testing._internal.inductor_utils import (
+    GPU_TYPE,
+    HAS_CPU,
+    HAS_CUDA,
+    HAS_GPU,
+)


 torch.set_float32_matmul_precision("high")
@@ -981,6 +987,8 @@ def test_conv_backend(self):

         self.assertIn("NoValidChoicesError", str(context.exception))

+    # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
+    @largeTensorTest("30 GB", device=GPU_TYPE)
     def test_non_contiguous_input_mm(self):
         """
         Make sure the triton template can work with non-contiguous inputs without crash.
@@ -1033,6 +1041,8 @@ def f(x, y):
     # TODO: fix accuracy failure of the triton template on XPU.
     # and enable this test case.
     @skipIfXpu
+    # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
+    @largeTensorTest("30 GB", device=GPU_TYPE)
     def test_non_contiguous_input_mm_plus_mm(self):
         x1 = rand_strided((50257, 32768), (1, 50304), device=GPU_TYPE)
         y1 = rand_strided((32768, 768), (768, 1), device=GPU_TYPE)
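As a rough illustration of the second fix: largeTensorTest checks whether the target device has the requested amount of memory available before the test body runs and skips it otherwise, so the memory-hungry autotune benchmarks skip cleanly instead of exhausting VRAM. A minimal sketch, assuming a CUDA device; the class and test names are hypothetical:

import unittest

import torch
from torch.testing._internal.common_device_type import largeTensorTest


class SketchAutotuneTest(unittest.TestCase):  # hypothetical test class
    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    @largeTensorTest("30 GB", device="cuda")  # skips unless ~30 GB is free
    def test_memory_hungry_matmul(self):
        # A real test would allocate the large non-contiguous operands and
        # run max-autotune compilation here; this stand-in just allocates.
        x = torch.rand(1024, 1024, device="cuda")
        self.assertEqual((x @ x).shape, (1024, 1024))


if __name__ == "__main__":
    unittest.main()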
