
Commit 0016598

[release/2.6] NAVI32 specific fixes - Used general decorator instead of ROCm specific (#2515)
Fixes ROCm/frameworks-internal#12096
1 parent 2e48b21 commit 0016598

File tree

2 files changed: +8 −29 lines changed

test/inductor/test_max_autotune.py

Lines changed: 8 additions & 4 deletions
@@ -30,14 +30,18 @@
 from torch._inductor.virtualized import V
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing import FileCheck
+from torch.testing._internal.common_device_type import largeTensorTest
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
-    skipIfRocmNotEnoughMemory,
     skipIfRocm,
     TEST_WITH_ROCM,
 )
-from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
+from torch.testing._internal.inductor_utils import (
+    GPU_TYPE,
+    HAS_CPU,
+    HAS_CUDA,
+)
 
 
 torch.set_float32_matmul_precision("high")
@@ -719,7 +723,7 @@ def test_conv_backend(self):
         self.assertIn("NoValidChoicesError", str(context.exception))
 
     # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
-    @skipIfRocmNotEnoughMemory(30)
+    @largeTensorTest("30 GB", device=GPU_TYPE)
     def test_non_contiguous_input_mm(self):
         """
         Make sure the triton template can work with non-contiguous inputs without crash.
@@ -770,7 +774,7 @@ def f(x, y):
         torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)
 
     # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks
-    @skipIfRocmNotEnoughMemory(30)
+    @largeTensorTest("30 GB", device=GPU_TYPE)
     def test_non_contiguous_input_mm_plus_mm(self):
         x1 = rand_strided((50257, 32768), (1, 50304), device="cuda")
         y1 = rand_strided((32768, 768), (768, 1), device="cuda")
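
For context, here is a minimal sketch of the general decorator in use. It is illustrative only and not part of this commit; the test class and method names are hypothetical. largeTensorTest skips a test when the target device cannot supply the requested amount of memory, so one guard now covers CUDA and ROCm alike:

# Minimal sketch of the general decorator in use; ExampleTest and
# test_big_matmul are hypothetical names, not part of this commit.
import torch
from torch.testing._internal.common_device_type import largeTensorTest
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.inductor_utils import GPU_TYPE


class ExampleTest(TestCase):
    # Skipped automatically unless the GPU can provide ~30 GB,
    # regardless of vendor (CUDA or ROCm).
    @largeTensorTest("30 GB", device=GPU_TYPE)
    def test_big_matmul(self):
        x = torch.rand((50257, 32768), device=GPU_TYPE)
        y = torch.rand((32768, 768), device=GPU_TYPE)
        torch.mm(x, y)


if __name__ == "__main__":
    run_tests()

Dropping the ROCm-only decorator in favor of the shared one means NAVI32 and other low-VRAM parts get the same skip behavior as memory-constrained CUDA devices, with no bespoke code path to maintain.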

torch/testing/_internal/common_utils.py

Lines changed: 0 additions & 25 deletions
@@ -1872,31 +1872,6 @@ def wrap_fn(self, *args, **kwargs):
         return wrap_fn
     return dec_fn
 
-# Checks if current ROCm device has enough VRAM against the required amount in GB
-def skipIfRocmNotEnoughMemory(required_amount):
-    def dec_fn(fn):
-        @wraps(fn)
-        def wrap_fn(self, *args, **kwargs):
-            if TEST_WITH_ROCM:
-                device = torch.cuda.current_device()
-                props = torch.cuda.get_device_properties(device)
-
-                total = props.total_memory / (1024 ** 3)  # in GB
-                # This will probably return 0 because it only counts tensors
-                # and doesn't take into account any small supporting allocations
-                allocated = torch.cuda.memory_allocated(device) / (1024 ** 3)
-                free_global = total - allocated
-
-                result = free_global > required_amount
-
-                if not result:
-                    reason = f"skipIfRocm: Not enough free VRAM on current ROCm device. " \
-                             f"Available: {free_global:.2f} GB | Required: {required_amount:.2f} GB."
-                    raise unittest.SkipTest(reason)
-            return fn(self, *args, **kwargs)
-        return wrap_fn
-    return dec_fn
-
 def runOnRocm(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
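
As the removed code's own comment notes, torch.cuda.memory_allocated() only counts live tensor allocations tracked by the caching allocator, so total - allocated tends to overstate the memory actually free on the device. A small sketch (again illustrative, not part of this commit) contrasting that estimate with the driver-reported figure from torch.cuda.mem_get_info():

# Sketch contrasting the removed decorator's free-VRAM estimate with the
# driver-reported figure; illustrative only, not part of this commit.
import torch

if torch.cuda.is_available():
    device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)

    # The removed decorator's estimate: counts live tensor allocations only.
    total = props.total_memory / (1024 ** 3)
    allocated = torch.cuda.memory_allocated(device) / (1024 ** 3)
    print(f"naive estimate: {total - allocated:.2f} GB free")

    # Driver-reported figure: accounts for every allocation on the device.
    free_bytes, _ = torch.cuda.mem_get_info(device)
    print(f"driver-reported: {free_bytes / (1024 ** 3):.2f} GB free")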
