
Commit 1897a54

[AUTOGENERATED] [rocm7.1_internal_testing] NAVI32 specific fixes (#2473)
Cherry-pick of #2450. Co-authored-by: iupaikov-amd <[email protected]>
1 parent 97f8ab5 commit 1897a54

File tree

3 files changed, +36 -4 lines changed

test/inductor/test_flex_decoding.py

Lines changed: 6 additions & 1 deletion

@@ -22,7 +22,11 @@
 )
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, with_tf32_off
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    PLATFORM_SUPPORTS_BF16,
+    with_tf32_off,
+)
 from torch.testing._internal.common_device_type import (
     flex_attention_supported_platform as supported_platform,
     instantiate_device_type_tests,
@@ -1591,6 +1595,7 @@ def mask_mod(b, h, q, kv):
         self.assertEqual(out[:, :, M:, :].sum(), 0)

     @supported_platform
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_windowed_no_mask_vs_sdpa(self, device):
         score_mod = _generate_windowed(1000)
         attention = functools.partial(flex_attention, score_mod=score_mod)
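For context, the added decorator follows the standard unittest capability-gating pattern: the platform flag is evaluated once at import time, and architectures without flash-attention/SDPA support report the test as skipped rather than failed. A minimal, self-contained sketch of the same pattern; the hard-coded flag below is only a stand-in for PLATFORM_SUPPORTS_FLASH_ATTENTION, which PyTorch derives from the detected GPU:

import unittest

# Stand-in for torch.testing._internal.common_cuda.PLATFORM_SUPPORTS_FLASH_ATTENTION;
# in the real suite this flag is computed from the detected GPU architecture.
SUPPORTS_FLASH_ATTENTION = False

class WindowedAttentionTest(unittest.TestCase):
    @unittest.skipIf(not SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
    def test_windowed_no_mask_vs_sdpa(self):
        # The real test compares flex_attention output against an SDPA reference here.
        self.assertTrue(True)

if __name__ == "__main__":
    unittest.main()  # the test is reported as skipped while the flag is False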

test/inductor/test_max_autotune.py

Lines changed: 5 additions & 3 deletions

@@ -45,8 +45,7 @@
     IS_WINDOWS,
     parametrize,
     TEST_WITH_ROCM,
-    NAVI_ARCH,
-    skipIfRocmArch,
+    skipIfRocmNotEnoughMemory,
 )
 from torch.testing._internal.logging_utils import multiple_logs_to_string
 from torch.utils._triton import has_triton_tma_device
@@ -819,6 +818,8 @@ def test_conv_backend(self):

         self.assertIn("NoValidChoicesError", str(context.exception))

+    # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
+    @skipIfRocmNotEnoughMemory(30)
     def test_non_contiguous_input_mm(self):
         """
         Make sure the triton template can work with non-contiguous inputs without crash.
@@ -837,7 +838,6 @@ def f(x, y):
         act = f(x, y)
         torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)

-    @skipIfRocmArch(NAVI_ARCH)
     def test_non_contiguous_input_addmm(self):
         b = torch.randn((768), dtype=torch.bfloat16, device=GPU_TYPE)
         x = rand_strided(
@@ -872,6 +872,8 @@ def f(x, y):
     # TODO: fix accuracy failure of the triton template on XPU.
     # and enable this test case.
     @skipIfXpu
+    # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
+    @skipIfRocmNotEnoughMemory(30)
     def test_non_contiguous_input_mm_plus_mm(self):
         x1 = rand_strided((50257, 2048), (1, 50304), device=GPU_TYPE)
         y1 = rand_strided((2048, 768), (768, 1), device=GPU_TYPE)
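The argument to skipIfRocmNotEnoughMemory is the amount of free VRAM, in GB, that the decorated test is expected to need; here the arch-based skipIfRocmArch(NAVI_ARCH) skip is dropped in favor of a check on available memory, so ROCm GPUs with enough capacity keep running these tests. A quick way to see where a given GPU lands relative to the 30 GB threshold, using only public torch.cuda calls (a sketch, not part of the commit):

import torch

# Report how much VRAM the current GPU exposes, mirroring the numbers the new
# decorator compares against its required_amount threshold.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    total_gb = props.total_memory / (1024 ** 3)
    allocated_gb = torch.cuda.memory_allocated(device) / (1024 ** 3)
    print(f"{props.name}: {total_gb:.1f} GB total, "
          f"{total_gb - allocated_gb:.1f} GB not held by tensor allocations")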

torch/testing/_internal/common_utils.py

Lines changed: 25 additions & 0 deletions

@@ -1943,6 +1943,31 @@ def wrap_fn(self, *args, **kwargs):
         return wrap_fn
     return dec_fn

+# Checks if current ROCm device has enough VRAM against the required amount in GB
+def skipIfRocmNotEnoughMemory(required_amount):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                device = torch.cuda.current_device()
+                props = torch.cuda.get_device_properties(device)
+
+                total = props.total_memory / (1024 ** 3)  # in GB
+                # This will probably return 0 because it only counts tensors
+                # and doesn't take into account any small supporting allocations
+                allocated = torch.cuda.memory_allocated(device) / (1024 ** 3)
+                free_global = total - allocated
+
+                result = free_global > required_amount
+
+                if not result:
+                    reason = f"skipIfRocm: Not enough free VRAM on current ROCm device. " \
+                             f"Available: {free_global:.2f} GB | Required: {required_amount:.2f} GB."
+                    raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
 def runOnRocm(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
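Assuming a PyTorch checkout that includes this commit, the new helper is used like the other skip decorators in common_utils: it only acts when TEST_WITH_ROCM is set and otherwise lets the test run unchanged. A usage sketch with a hypothetical test class (not part of the commit):

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    skipIfRocmNotEnoughMemory,
)

class BigMatmulTest(TestCase):  # hypothetical test class, for illustration only
    # Skipped on ROCm devices reporting less than ~16 GB of free VRAM;
    # on CUDA, and on ROCm devices with enough memory, the test runs normally.
    @skipIfRocmNotEnoughMemory(16)
    def test_large_mm(self):
        a = torch.randn(2048, 2048, device="cuda")
        b = torch.randn(2048, 2048, device="cuda")
        self.assertEqual((a @ b).shape, torch.Size([2048, 2048]))

if __name__ == "__main__":
    run_tests()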
