
Commit 8fbb36e

iupaikov-amd authored and jithunnair-amd committed
Added a decorator to skip tests that require more memory than the GPU has
1 parent 66ffd4a · commit 8fbb36e

2 files changed: +21 −10 lines changed

test/inductor/test_max_autotune.py

Lines changed: 5 additions & 6 deletions
@@ -33,10 +33,9 @@
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
+    skipIfRocmNotEnoughMemory,
     skipIfRocm,
-    skipIfRocmArch,
     TEST_WITH_ROCM,
-    NAVI32_ARCH,
 )
 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA

@@ -719,8 +718,8 @@ def test_conv_backend(self):

         self.assertIn("NoValidChoicesError", str(context.exception))

-    # NAVI32 doesn't have enough VRAM to run all autotune configurations and padding benchmarks
-    @skipIfRocmArch(NAVI32_ARCH)
+    # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
+    @skipIfRocmNotEnoughMemory(30)
     def test_non_contiguous_input_mm(self):
         """
         Make sure the triton template can work with non-contiguous inputs without crash.
@@ -770,8 +769,8 @@ def f(x, y):
         act = f(x, y)
         torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)

-    # NAVI32 doesn't have enough VRAM to run all autotune configurations and padding benchmarks
-    @skipIfRocmArch(NAVI32_ARCH)
+    # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
+    @skipIfRocmNotEnoughMemory(30)
     def test_non_contiguous_input_mm_plus_mm(self):
         x1 = rand_strided((50257, 32768), (1, 50304), device="cuda")
         y1 = rand_strided((32768, 768), (768, 1), device="cuda")
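
For reference, a minimal usage sketch of the new decorator; the numeric argument is the free VRAM, in GB, that the test needs. The test class name and the 16 GB threshold below are illustrative only, not part of this commit.

# Illustrative sketch: `ExampleTests` and the 16 GB threshold are hypothetical.
# `skipIfRocmNotEnoughMemory` is the decorator added by this commit in
# torch/testing/_internal/common_utils.py.
from torch.testing._internal.common_utils import TestCase, skipIfRocmNotEnoughMemory


class ExampleTests(TestCase):
    # Skipped on ROCm devices with less than ~16 GB of free VRAM
    @skipIfRocmNotEnoughMemory(16)
    def test_memory_hungry_kernel(self):
        ...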

torch/testing/_internal/common_utils.py

Lines changed: 16 additions & 4 deletions
@@ -1872,15 +1872,27 @@ def wrap_fn(self, *args, **kwargs):
         return wrap_fn
     return dec_fn

-def skipIfRocmArch(arch: tuple[str, ...]):
+# Checks if current ROCm device has enough VRAM against the required amount in GB
+def skipIfRocmNotEnoughMemory(required_amount):
     def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
             if TEST_WITH_ROCM:
-                prop = torch.cuda.get_device_properties(0)
-                if prop.gcnArchName.split(":")[0] in arch:
-                    reason = f"skipIfRocm: test skipped on {arch}"
+                device = torch.cuda.current_device()
+                props = torch.cuda.get_device_properties(device)
+
+                total = props.total_memory / (1024 ** 3)  # in GB
+                # This will probably return 0 because it only counts tensors
+                # and doesn't take into account any small supporting allocations
+                allocated = torch.cuda.memory_allocated(device) / (1024 ** 3)
+                free_global = total - allocated
+
+                result = free_global > required_amount
+
+                if not result:
+                    reason = f"skipIfRocm: Not enough free VRAM on current ROCm device. Available {free_global:.2f} GB | Required {required_amount:.2f} GB."
                     raise unittest.SkipTest(reason)
+
             return fn(self, *args, **kwargs)
         return wrap_fn
     return dec_fn
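
The core of the new check is the arithmetic total_memory - memory_allocated compared against the requested threshold in GB. Below is a standalone sketch of that computation as a hypothetical helper (not part of the commit), using only the torch.cuda calls that appear in the diff.

import torch


def has_free_vram_gb(required_gb, device=None):
    # Hypothetical helper mirroring the decorator's arithmetic:
    # total device memory minus tensor allocations, converted to GB.
    # As the diff's comment notes, memory_allocated() only counts tensors,
    # so this is an optimistic estimate of the truly free memory.
    if device is None:
        device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    total_gb = props.total_memory / (1024 ** 3)
    allocated_gb = torch.cuda.memory_allocated(device) / (1024 ** 3)
    return (total_gb - allocated_gb) > required_gb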
