Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/lightning/fabric/accelerators/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ def teardown(self) -> None:

@staticmethod
@override
def parse_devices(devices: Union[int, str, List[int]]) -> int:
def parse_devices(devices: Union[int, str]) -> int:
"""Accelerator device parsing logic."""
return _parse_cpu_cores(devices)

@staticmethod
@override
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
devices = _parse_cpu_cores(devices)
return [torch.device("cpu")] * devices
Expand All @@ -72,12 +72,12 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
)


def _parse_cpu_cores(cpu_cores: Union[int, str, List[int]]) -> int:
def _parse_cpu_cores(cpu_cores: Union[int, str]) -> int:
"""Parses the cpu_cores given in the format as accepted by the ``devices`` argument in the
:class:`~lightning.pytorch.trainer.trainer.Trainer`.

Args:
cpu_cores: An int > 0.
cpu_cores: An int > 0 or a string that can be converted to an int > 0.

Returns:
An int representing the number of processes
Expand Down
Loading