Skip to content

Commit b2697b5

Browse files
committed
fix mock issues with devices
1 parent 492115e commit b2697b5

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -491,10 +491,11 @@ def test_strategy_choice_ddp_torchelastic(_, __, mps_count_0, cuda_count_2):
         "LOCAL_RANK": "1",
     },
 )
-@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
-@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
-def test_torchelastic_priority_over_slurm(*_):
+def test_torchelastic_priority_over_slurm(monkeypatch):
     """Test that the TorchElastic cluster environment is chosen over SLURM when both are detected."""
+    mock_cuda_count(monkeypatch, 2)
+    mock_mps_count(monkeypatch, 0)
+    mock_hpu_count(monkeypatch, 0)
     assert TorchElasticEnvironment.detect()
     assert SLURMEnvironment.detect()
     connector = _AcceleratorConnector(strategy="ddp")
@@ -1003,6 +1004,7 @@ def _mock_tpu_available(value):
     with monkeypatch.context():
         mock_cuda_count(monkeypatch, 2)
         mock_mps_count(monkeypatch, 0)
+        mock_hpu_count(monkeypatch, 0)
         _mock_tpu_available(True)
         connector = _AcceleratorConnector()
         assert isinstance(connector.accelerator, XLAAccelerator)

0 commit comments

Comments (0)