Skip to content

Commit 68e31e2

Browse files
eqy authored and pytorchmergebot committed
[CUDA] Skip pynvml test on platforms that don't have complete support (pytorch#159689)
Pull Request resolved: pytorch#159689 Approved by: https://github.com/msaroufim, https://github.com/Skylion007
1 parent ee1bc3f commit 68e31e2

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

test/test_cuda.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
3636
from torch.testing._internal.common_cuda import (
3737
_create_scaling_case,
38+
HAS_WORKING_NVML,
3839
SM70OrLater,
3940
TEST_CUDNN,
4041
TEST_MULTIGPU,
@@ -4803,6 +4804,7 @@ def test_nvml_get_handler(self):
48034804
def test_temperature(self):
48044805
self.assertTrue(0 <= torch.cuda.temperature() <= 150)
48054806

4807+
@unittest.skipIf(not HAS_WORKING_NVML, "pynvml available but broken")
48064808
@unittest.skipIf(TEST_WITH_ROCM, "flaky for AMD gpu")
48074809
@unittest.skipIf(not TEST_PYNVML, "pynvml/amdsmi is not available")
48084810
def test_device_memory_used(self):

torch/testing/_internal/common_cuda.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,20 @@ def xfailIfSM120OrLater(func):
376376
def xfailIfDistributedNotSupported(func):
377377
return func if not (IS_MACOS or IS_JETSON) else unittest.expectedFailure(func)
378378

379+
def _check_has_working_nvml() -> bool:
380+
try:
381+
if not torch.cuda.is_available():
382+
return False
383+
import pynvml
384+
torch.cuda.device_memory_used()
385+
return True
386+
except ModuleNotFoundError:
387+
return False
388+
except pynvml.NVMLError_NotSupported:
389+
return False
390+
391+
HAS_WORKING_NVML = _check_has_working_nvml()
392+
379393
# Importing this module should NOT eagerly initialize CUDA
380394
if not CUDA_ALREADY_INITIALIZED_ON_IMPORT:
381395
assert not torch.cuda.is_initialized()

0 commit comments

Comments
 (0)