Commit 944ad69

fix tests.
1 parent 8cb809f commit 944ad69

File tree

4 files changed: +32 -0 lines changed

  tests/tests_pytorch/accelerators/test_gpu.py
  tests/tests_pytorch/accelerators/test_mps.py
  tests/tests_pytorch/accelerators/test_xla.py
  tests/tests_pytorch/conftest.py

tests/tests_pytorch/accelerators/test_gpu.py

Lines changed: 14 additions & 0 deletions
@@ -68,3 +68,17 @@ def test_gpu_availability():
 def test_warning_if_gpus_not_used(cuda_count_1):
     with pytest.warns(UserWarning, match="GPU available but not used"):
         Trainer(accelerator="cpu")
+
+
+@RunIf(min_cuda_gpus=1)
+def test_gpu_device_name():
+    for i in range(torch.cuda.device_count()):
+        assert torch.cuda.get_device_name(i) == CUDAAccelerator.device_name(i)
+
+    with torch.device("cuda:0"):
+        assert torch.cuda.get_device_name(0) == CUDAAccelerator.device_name()
+
+
+def test_gpu_device_name_no_gpu():
+    with mock.patch("torch.cuda.is_available", return_value=False):
+        assert str(False) == CUDAAccelerator.device_name()
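
The assertions above pin down the expected behaviour of CUDAAccelerator.device_name without showing its implementation. A minimal sketch consistent with these tests (an assumption for illustration, not the actual code in lightning.pytorch.accelerators.cuda) could look like:

    import torch

    def cuda_device_name_sketch(device=None) -> str:
        # Hypothetical helper inferred from the tests above, not Lightning's implementation.
        # test_gpu_device_name_no_gpu expects the string "False" when CUDA is unavailable.
        if not torch.cuda.is_available():
            return str(False)
        # Otherwise delegate to PyTorch; get_device_name accepts an index, a torch.device,
        # or None, where None queries the current device.
        return torch.cuda.get_device_name(device)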

tests/tests_pytorch/accelerators/test_mps.py

Lines changed: 11 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from collections import namedtuple
+from unittest import mock
 
 import pytest
 import torch
@@ -39,6 +40,16 @@ def test_mps_availability():
     assert MPSAccelerator.is_available()
 
 
+@RunIf(mps=True)
+def test_mps_device_name():
+    assert MPSAccelerator.device_name() == "True (mps)"
+
+
+def test_mps_device_name_not_available():
+    with mock.patch("torch.backends.mps.is_available", return_value=False):
+        assert MPSAccelerator.device_name() == "False"
+
+
 def test_warning_if_mps_not_used(mps_count_1):
     with pytest.warns(UserWarning, match="GPU available but not used"):
         Trainer(accelerator="cpu")
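
Taken together, the two new MPS tests imply that MPSAccelerator.device_name reports backend availability rather than a hardware name. A rough sketch of that contract (assumed for illustration, not taken from the Lightning source) is:

    import torch

    def mps_device_name_sketch() -> str:
        # Hypothetical helper mirroring the values the tests assert:
        # "True (mps)" when the MPS backend is available, "False" otherwise.
        if torch.backends.mps.is_available():
            return "True (mps)"
        return str(False)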

tests/tests_pytorch/accelerators/test_xla.py

Lines changed: 5 additions & 0 deletions
@@ -302,6 +302,11 @@ def test_warning_if_tpus_not_used(tpu_available):
         Trainer(accelerator="cpu")
 
 
+@RunIf(tpu=True)
+def test_tpu_device_name():
+    assert XLAAccelerator.device_name() == "TPU"
+
+
 @pytest.mark.parametrize(
     ("devices", "expected_device_ids"),
     [

tests/tests_pytorch/conftest.py

Lines changed: 2 additions & 0 deletions
@@ -182,6 +182,7 @@ def thread_police_duuu_daaa_duuu_daaa():
 def mock_cuda_count(monkeypatch, n: int) -> None:
     monkeypatch.setattr(lightning.fabric.accelerators.cuda, "num_cuda_devices", lambda: n)
     monkeypatch.setattr(lightning.pytorch.accelerators.cuda, "num_cuda_devices", lambda: n)
+    monkeypatch.setattr(torch.cuda, "get_device_name", lambda _: "Mocked CUDA Device")
 
 
 @pytest.fixture
@@ -244,6 +245,7 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N
     monkeypatch.setitem(sys.modules, "torch_xla", Mock())
     monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
     monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
+    monkeypatch.setattr(lightning.pytorch.accelerators.xla.XLAAccelerator, "device_name", lambda _: "TPU")
 
 
 @pytest.fixture
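
These conftest patches keep the new device_name tests hermetic on machines without the corresponding hardware. A rough illustration of the effect (the fixture wiring below is an assumption based on the cuda_count_1 fixture used in test_gpu.py, not part of this diff):

    import torch

    def test_device_name_is_mocked(cuda_count_1):
        # Hypothetical test: assumes the cuda_count_1 fixture applies mock_cuda_count,
        # so torch.cuda.get_device_name returns the patched constant for any index.
        assert torch.cuda.get_device_name(0) == "Mocked CUDA Device"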
