
Commit 608069b

okakarpa and iupaikov-amd authored and committed
[AUTOGENERATED] [release/2.8] NAVI32 specific fixes (#2467)
Cherry-pick of #2450

---------

Co-authored-by: iupaikov-amd <[email protected]>
1 parent d9d5b96 commit 608069b

2 files changed (+13 −2 lines)

test/inductor/test_flex_decoding.py

Lines changed: 5 additions & 1 deletion

@@ -22,7 +22,10 @@
 )
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_BF16,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
+)
 from torch.testing._internal.common_device_type import (
     flex_attention_supported_platform as supported_platform,
     instantiate_device_type_tests,
@@ -1582,6 +1585,7 @@ def mask_mod(b, h, q, kv):
         self.assertEqual(out[:, :, M:, :].sum(), 0)
 
     @supported_platform
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_windowed_no_mask_vs_sdpa(self, device):
         score_mod = _generate_windowed(1000)
         attention = functools.partial(flex_attention, score_mod=score_mod)
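
A note on the skip pattern above: unittest.skipIf evaluates its condition when the class body is executed, so on platforms where PLATFORM_SUPPORTS_FLASH_ATTENTION is false (e.g. the NAVI32 arch this commit targets) the test is reported as skipped rather than failed. Below is a minimal self-contained sketch of the same pattern; the platform flag is hard-coded here as a stand-in for PyTorch's internal torch.testing._internal.common_cuda probe, so the sketch runs anywhere.

import unittest

# Stand-in for torch.testing._internal.common_cuda.PLATFORM_SUPPORTS_FLASH_ATTENTION;
# hard-coded so this sketch is runnable without a GPU.
PLATFORM_SUPPORTS_FLASH_ATTENTION = False

class WindowedAttentionTest(unittest.TestCase):
    # The decorator is bound at class-definition time; the runner reports
    # a skip (not a failure) when the platform lacks flash attention.
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
    )
    def test_windowed_no_mask_vs_sdpa(self):
        # A real test would compare flex_attention output against SDPA here.
        self.assertTrue(True)

if __name__ == "__main__":
    unittest.main()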

test/inductor/test_max_autotune.py

Lines changed: 8 additions & 1 deletion

@@ -37,11 +37,15 @@
 )
 from torch._inductor.template_heuristics import CUDAConfigHeuristic, GemmConfig
 from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
+from torch.testing._internal.common_device_type import largeTensorTest
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     IS_WINDOWS,
     parametrize,
     TEST_WITH_ROCM,
+    MI300_ARCH,
+    runOnRocmArch,
+    skipIfXpu,
 )
 from torch.testing._internal.logging_utils import multiple_logs_to_string
 from torch.utils._triton import has_triton_tma_device
@@ -54,7 +58,6 @@
 from torch._inductor.virtualized import V
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing import FileCheck
-from torch.testing._internal.common_utils import MI300_ARCH, runOnRocmArch, skipIfXpu
 from torch.testing._internal.inductor_utils import (
     get_func_call,
     get_kernel_launch,
@@ -804,6 +807,8 @@ def test_conv_backend(self):
 
         self.assertIn("NoValidChoicesError", str(context.exception))
 
+    # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
+    @largeTensorTest("30 GB", device=GPU_TYPE)
     def test_non_contiguous_input_mm(self):
         """
         Make sure the triton template can work with non-contiguous inputs without crash.
@@ -856,6 +861,8 @@ def f(x, y):
     # TODO: fix accuracy failure of the triton template on XPU.
     # and enable this test case.
     @skipIfXpu
+    # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
+    @largeTensorTest("30 GB", device=GPU_TYPE)
     def test_non_contiguous_input_mm_plus_mm(self):
         x1 = rand_strided((50257, 32768), (1, 50304), device=GPU_TYPE)
         y1 = rand_strided((32768, 768), (768, 1), device=GPU_TYPE)
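
For context on @largeTensorTest("30 GB", device=GPU_TYPE): the decorator (from torch.testing._internal.common_device_type) skips a test when the device cannot supply the requested amount of memory, which is why it gates these autotune tests on smaller ROCm GPUs. The sketch below is a simplified, hypothetical analogue of that behavior, not PyTorch's actual implementation; large_tensor_test is an illustrative name.

import functools
import unittest

import torch

def large_tensor_test(required_gib: float):
    # Hypothetical, simplified analogue of largeTensorTest: skip when the
    # GPU reports less free memory than the test is expected to need.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            if not torch.cuda.is_available():
                raise unittest.SkipTest("requires a CUDA/ROCm GPU")
            free_bytes, _total = torch.cuda.mem_get_info()
            if free_bytes < required_gib * 1024**3:
                raise unittest.SkipTest(
                    f"needs ~{required_gib} GiB free, "
                    f"have {free_bytes / 1024**3:.1f} GiB"
                )
            return fn(self, *args, **kwargs)
        return wrapper
    return decorator

class NonContiguousMMTest(unittest.TestCase):
    @large_tensor_test(30.0)
    def test_non_contiguous_input_mm(self):
        # Non-contiguous input, in the spirit of rand_strided in the diff:
        # transposing makes the tensor non-contiguous without copying.
        x = torch.randn(1024, 2048, device="cuda").t()
        y = torch.randn(1024, 512, device="cuda")
        self.assertFalse(x.is_contiguous())
        torch.mm(x, y)  # the mm kernel must handle the strided layout

if __name__ == "__main__":
    unittest.main()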
