Skip to content

Commit 20d19d2

Browse files
chualanagitAlan ChuBorda
authored
Remove List[int] as input type for Trainer when accelerator="cpu" (#20399)
Co-authored-by: Alan Chu <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
1 parent e1b172c commit 20d19d2

File tree

2 files changed

+6
-6
lines changed
  • src/lightning

2 files changed

+6
-6
lines changed

src/lightning/fabric/accelerators/cpu.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,13 @@ def teardown(self) -> None:
3939

4040
@staticmethod
4141
@override
42-
def parse_devices(devices: Union[int, str, List[int]]) -> int:
42+
def parse_devices(devices: Union[int, str]) -> int:
4343
"""Accelerator device parsing logic."""
4444
return _parse_cpu_cores(devices)
4545

4646
@staticmethod
4747
@override
48-
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
48+
def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]:
4949
"""Gets parallel devices for the Accelerator."""
5050
devices = _parse_cpu_cores(devices)
5151
return [torch.device("cpu")] * devices
@@ -72,12 +72,12 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
7272
)
7373

7474

75-
def _parse_cpu_cores(cpu_cores: Union[int, str, List[int]]) -> int:
75+
def _parse_cpu_cores(cpu_cores: Union[int, str]) -> int:
7676
"""Parses the cpu_cores given in the format as accepted by the ``devices`` argument in the
7777
:class:`~lightning.pytorch.trainer.trainer.Trainer`.
7878
7979
Args:
80-
cpu_cores: An int > 0.
80+
cpu_cores: An int > 0 or a string that can be converted to an int > 0.
8181
8282
Returns:
8383
An int representing the number of processes

src/lightning/pytorch/accelerators/cpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,13 @@ def teardown(self) -> None:
4848

4949
@staticmethod
5050
@override
51-
def parse_devices(devices: Union[int, str, List[int]]) -> int:
51+
def parse_devices(devices: Union[int, str]) -> int:
5252
"""Accelerator device parsing logic."""
5353
return _parse_cpu_cores(devices)
5454

5555
@staticmethod
5656
@override
57-
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
57+
def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]:
5858
"""Gets parallel devices for the Accelerator."""
5959
devices = _parse_cpu_cores(devices)
6060
return [torch.device("cpu")] * devices

0 commit comments

Comments
 (0)