Skip to content

Commit 0a5725b

Browse files
committed
fix tests.
1 parent 92b1d69 commit 0a5725b

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

tests/tests_pytorch/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,11 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N
245245
monkeypatch.setitem(sys.modules, "torch_xla", Mock())
246246
monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
247247
monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
248-
monkeypatch.setattr(lightning.pytorch.accelerators.xla.XLAAccelerator, "device_name", lambda _: "TPU")
248+
monkeypatch.setattr(
249+
lightning.pytorch.accelerators.xla.XLAAccelerator,
250+
"device_name",
251+
lambda *_: "Mocked TPU Device",
252+
)
249253

250254

251255
@pytest.fixture

tests/tests_pytorch/plugins/test_cluster_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_ranks_available_manual_strategy_selection(_, strategy_cls):
6666
"""Test that the rank information is readily available after Trainer initialization."""
6767
num_nodes = 2
6868
for cluster, variables, expected in environment_combinations():
69-
with mock.patch.dict(os.environ, variables):
69+
with mock.patch.dict(os.environ, variables), mock.patch("torch.cuda.get_device_name", return_value="GPU"):
7070
strategy = strategy_cls(
7171
parallel_devices=[torch.device("cuda", 1), torch.device("cuda", 2)], cluster_environment=cluster
7272
)

0 commit comments

Comments
 (0)