Skip to content

Commit 475d7d0

Browse files
authored
Add cpu device parser to validate cpu devices (#12160)
1 parent d9b1ff3 commit 475d7d0

File tree

6 files changed

+41
-20
lines changed

6 files changed

+41
-20
lines changed

pytorch_lightning/accelerators/cpu.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import torch
1717

1818
from pytorch_lightning.accelerators.accelerator import Accelerator
19+
from pytorch_lightning.utilities import device_parser
1920
from pytorch_lightning.utilities.exceptions import MisconfigurationException
20-
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
2121
from pytorch_lightning.utilities.types import _DEVICE
2222

2323

@@ -39,19 +39,16 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
3939
return {}
4040

4141
@staticmethod
42-
def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, str, List[int]]:
42+
def parse_devices(devices: Union[int, str, List[int]]) -> int:
4343
"""Accelerator device parsing logic."""
44+
devices = device_parser.parse_cpu_cores(devices)
4445
return devices
4546

4647
@staticmethod
4748
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
4849
"""Gets parallel devices for the Accelerator."""
49-
if isinstance(devices, int):
50-
return [torch.device("cpu")] * devices
51-
rank_zero_warn(
52-
f"The flag `devices` must be an int with `accelerator='cpu'`, got `devices={devices!r}` instead."
53-
)
54-
return []
50+
devices = device_parser.parse_cpu_cores(devices)
51+
return [torch.device("cpu")] * devices
5552

5653
@staticmethod
5754
def auto_device_count() -> int:

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
512512
self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag)
513513

514514
def _set_devices_flag_if_auto_passed(self) -> None:
515-
if self._devices_flag == "auto" or not self._devices_flag:
515+
if self._devices_flag == "auto" or self._devices_flag is None:
516516
self._devices_flag = self.accelerator.auto_device_count()
517517

518518
def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:

pytorch_lightning/utilities/device_parser.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,9 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i
7979
Returns:
8080
a list of gpus to be used or ``None`` if no GPUs were requested
8181
82-
If no GPUs are available but the value of gpus variable indicates request for GPUs
83-
then a MisconfigurationException is raised.
82+
Raises:
83+
MisconfigurationException:
84+
If no GPUs are available but the value of gpus variable indicates request for GPUs
8485
"""
8586
# Check that gpus param is None, Int, String or List
8687
_check_data_type(gpus)
@@ -137,6 +138,29 @@ def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional
137138
return tpu_cores
138139

139140

141+
def parse_cpu_cores(cpu_cores: Union[int, str, List[int]]) -> int:
142+
"""Parses the cpu_cores given in the format as accepted by the ``devices`` argument in the
143+
:class:`~pytorch_lightning.trainer.Trainer`.
144+
145+
Args:
146+
cpu_cores: An int > 0.
147+
148+
Returns:
149+
an int representing the number of processes
150+
151+
Raises:
152+
MisconfigurationException:
153+
If cpu_cores is not an int > 0
154+
"""
155+
if isinstance(cpu_cores, str) and cpu_cores.strip().isdigit():
156+
cpu_cores = int(cpu_cores)
157+
158+
if not isinstance(cpu_cores, int) or cpu_cores <= 0:
159+
raise MisconfigurationException("`devices` selected with `CPUAccelerator` should be an int > 0.")
160+
161+
return cpu_cores
162+
163+
140164
def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
141165
if not isinstance(s, str):
142166
return s

pytorch_lightning/utilities/parsing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def save_hyperparameters(
243243
if isinstance(init_args[k], nn.Module):
244244
rank_zero_warn(
245245
f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."
246-
" It is recommended to ignore them using `self.save_hyperparameters(ignore=[k!r])`."
246+
f" It is recommended to ignore them using `self.save_hyperparameters(ignore=[{k}!r])`."
247247
)
248248

249249
if not args:

tests/accelerators/test_accelerator_connector.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -570,14 +570,6 @@ def test_set_devices_if_none_cpu():
570570
assert trainer.num_devices == 3
571571

572572

573-
def test_devices_with_cpu_only_supports_integer():
574-
575-
with pytest.warns(UserWarning, match="The flag `devices` must be an int"):
576-
trainer = Trainer(accelerator="cpu", devices="1,3")
577-
assert isinstance(trainer.accelerator, CPUAccelerator)
578-
assert trainer.num_devices == 1
579-
580-
581573
@pytest.mark.parametrize("training_type", ["ddp2", "dp"])
582574
def test_unsupported_strategy_types_on_cpu(training_type):
583575
with pytest.warns(UserWarning, match="is not supported on CPUs, hence setting `strategy='ddp"):

tests/accelerators/test_cpu.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
1212
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
1313
from pytorch_lightning.strategies import SingleDeviceStrategy
14+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
1415
from tests.helpers.boring_model import BoringModel
1516

1617

@@ -66,3 +67,10 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
6667
for func in (trainer.test, trainer.validate, trainer.predict):
6768
plugin.setup_called = False
6869
func(model, ckpt_path=checkpoint_path)
70+
71+
72+
@pytest.mark.parametrize("devices", ([3], -1, 0))
73+
def test_invalid_devices_with_cpu_accelerator(devices):
74+
"""Test invalid device flag raises MisconfigurationException with CPUAccelerator."""
75+
with pytest.raises(MisconfigurationException, match="should be an int > 0"):
76+
Trainer(accelerator="cpu", devices=devices)

0 commit comments

Comments
 (0)