Skip to content

Commit 041da41

Browse files
authored
Remove TPU Availability check from parse devices (#12326)
* Remove TPU Availability check from parse devices
* Update tests
1 parent 4fe0076 commit 041da41

File tree

3 files changed

+5
-12
lines changed

3 files changed

+5
-12
lines changed

pytorch_lightning/utilities/device_parser.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
1919
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
20-
from pytorch_lightning.utilities import _TPU_AVAILABLE
2120
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2221
from pytorch_lightning.utilities.types import _DEVICE
2322

@@ -122,7 +121,7 @@ def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional
122121
123122
Raises:
124123
MisconfigurationException:
125-
If TPU cores aren't 1 or 8 cores, or no TPU devices are found
124+
If TPU cores aren't 1, 8 or [<1-8>]
126125
"""
127126
_check_data_type(tpu_cores)
128127

@@ -132,9 +131,6 @@ def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional
132131
if not _tpu_cores_valid(tpu_cores):
133132
raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")
134133

135-
if tpu_cores is not None and not _TPU_AVAILABLE:
136-
raise MisconfigurationException("No TPU devices were found.")
137-
138134
return tpu_cores
139135

140136

tests/accelerators/test_accelerator_connector.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -446,8 +446,7 @@ def test_ipython_compatible_dp_strategy_gpu(_, monkeypatch):
446446

447447

448448
@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
449-
@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
450-
def test_ipython_compatible_strategy_tpu(mock_devices, mock_tpu_acc_avail, monkeypatch):
449+
def test_ipython_compatible_strategy_tpu(mock_tpu_acc_avail, monkeypatch):
451450
monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
452451
trainer = Trainer(accelerator="tpu")
453452
assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible
@@ -894,8 +893,7 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock
894893

895894

896895
@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
897-
@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
898-
def test_unsupported_tpu_choice(mock_devices, mock_tpu_acc_avail):
896+
def test_unsupported_tpu_choice(mock_tpu_acc_avail):
899897

900898
with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"):
901899
Trainer(accelerator="tpu", precision=64)

tests/deprecated_api/test_remove_1-8.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,9 +1139,8 @@ def test_trainer_gpus(monkeypatch, trainer_kwargs):
11391139

11401140

11411141
def test_trainer_tpu_cores(monkeypatch):
1142-
monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "is_available", lambda: True)
1143-
monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "parse_devices", lambda: 8)
1144-
trainer = Trainer(accelerator="TPU", devices=8)
1142+
monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "is_available", lambda _: True)
1143+
trainer = Trainer(accelerator="tpu", devices=8)
11451144
with pytest.deprecated_call(
11461145
match="`Trainer.tpu_cores` is deprecated in v1.6 and will be removed in v1.8. "
11471146
"Please use `Trainer.num_devices` instead."

0 commit comments

Comments (0)