Commit 32442c2

Add name to accelerator interface (#21325)
1 parent 739cf13 commit 32442c2

11 files changed: +76, -12 lines


src/lightning/fabric/accelerators/accelerator.py

Lines changed: 6 additions & 0 deletions
@@ -56,6 +56,12 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         """Detect if the hardware is available."""
 
+    @staticmethod
+    @abstractmethod
+    def name() -> str:
+        """The name of the accelerator."""
+
     @classmethod
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
+        """Register the accelerator with the registry."""
        pass
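
With `name()` now abstract on the base class, any custom accelerator must implement it alongside the other static hooks. Below is a minimal sketch of what a third-party subclass might look like against this interface; the `MyNPUAccelerator` class and its `"mynpu"` key are invented for illustration, and the other overrides shown are the hooks the fabric `Accelerator` base class already defines.

```python
from typing import Any

import torch
from typing_extensions import override

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry


class MyNPUAccelerator(Accelerator):
    """Hypothetical accelerator illustrating the extended interface."""

    @staticmethod
    @override
    def name() -> str:
        # Single source of truth for the registry key and user-facing selector.
        return "mynpu"

    @staticmethod
    @override
    def is_available() -> bool:
        return False  # replace with a real capability probe

    @staticmethod
    @override
    def auto_device_count() -> int:
        return 1

    @staticmethod
    @override
    def parse_devices(devices: Any) -> Any:
        return devices

    @staticmethod
    @override
    def get_parallel_devices(devices: Any) -> Any:
        return [torch.device("cpu")]  # placeholder; map indices to real devices

    @override
    def setup_device(self, device: torch.device) -> None:
        pass

    @override
    def teardown(self) -> None:
        pass

    @classmethod
    @override
    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
        # Reuse name() instead of repeating a string literal, as this commit does.
        accelerator_registry.register(cls.name(), cls, description=cls.__name__)
```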

src/lightning/fabric/accelerators/cpu.py

Lines changed: 6 additions & 1 deletion
@@ -62,11 +62,16 @@ def is_available() -> bool:
         """CPU is always available for execution."""
         return True
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "cpu"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
-            "cpu",
+            cls.name(),
             cls,
             description=cls.__name__,
         )
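
Because the registry key is now derived from `name()` rather than a repeated string literal, the class itself documents the selector string users pass. A quick check (a sketch, assuming `lightning` is installed):

```python
from lightning.fabric.accelerators.cpu import CPUAccelerator

assert CPUAccelerator.name() == "cpu"
# The same string keys the registry, so Fabric(accelerator=CPUAccelerator.name())
# selects this class just as accelerator="cpu" always has.
```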

src/lightning/fabric/accelerators/cuda.py

Lines changed: 6 additions & 1 deletion
@@ -66,11 +66,16 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return num_cuda_devices() > 0
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "cuda"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
-            "cuda",
+            cls.name(),
             cls,
             description=cls.__name__,
         )

src/lightning/fabric/accelerators/mps.py

Lines changed: 6 additions & 1 deletion
@@ -74,11 +74,16 @@ def is_available() -> bool:
         mps_disabled = os.getenv("DISABLE_MPS", "0") == "1"
         return not mps_disabled and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64")
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "mps"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
-            "mps",
+            cls.name(),
             cls,
             description=cls.__name__,
         )

src/lightning/fabric/accelerators/xla.py

Lines changed: 10 additions & 1 deletion
@@ -93,10 +93,19 @@ def is_available() -> bool:
             # when `torch_xla` is imported but not used
             return False
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "tpu"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
-        accelerator_registry.register("tpu", cls, description=cls.__name__)
+        accelerator_registry.register(
+            cls.name(),
+            cls,
+            description=cls.__name__,
+        )
 
 
 # PJRT support requires this minimum version
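
Note that the XLA accelerator reports `"tpu"` rather than `"xla"`: `name()` preserves the key the class was already registered under, so existing `accelerator="tpu"` selectors continue to resolve to the same class.

```python
from lightning.fabric.accelerators.xla import XLAAccelerator

# Deliberately "tpu" for backward compatibility with the old registry key.
assert XLAAccelerator.name() == "tpu"
```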

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Add MPS accelerator support for mixed precision ([#21209](https://github.com/Lightning-AI/pytorch-lightning/pull/21209))
 
+- Add `name()` function to accelerator interface ([#21325](https://github.com/Lightning-AI/pytorch-lightning/pull/21325))
+
 
 ### Removed
 

src/lightning/pytorch/accelerators/cpu.py

Lines changed: 7 additions & 2 deletions
@@ -17,8 +17,8 @@
 from lightning_utilities.core.imports import RequirementCache
 from typing_extensions import override
 
-from lightning.fabric.accelerators import _AcceleratorRegistry
 from lightning.fabric.accelerators.cpu import _parse_cpu_cores
+from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.accelerators.accelerator import Accelerator
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -71,11 +71,16 @@ def is_available() -> bool:
         """CPU is always available for execution."""
         return True
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "cpu"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
-            "cpu",
+            cls.name(),
             cls,
             description=cls.__name__,
         )

src/lightning/pytorch/accelerators/cuda.py

Lines changed: 7 additions & 2 deletions
@@ -21,8 +21,8 @@
 from typing_extensions import override
 
 import lightning.pytorch as pl
-from lightning.fabric.accelerators import _AcceleratorRegistry
 from lightning.fabric.accelerators.cuda import _check_cuda_matmul_precision, _clear_cuda_memory, num_cuda_devices
+from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.utilities.device_parser import _parse_gpu_ids
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.accelerators.accelerator import Accelerator
@@ -104,11 +104,16 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return num_cuda_devices() > 0
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "cuda"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
-            "cuda",
+            cls.name(),
             cls,
             description=cls.__name__,
         )

src/lightning/pytorch/accelerators/mps.py

Lines changed: 7 additions & 2 deletions
@@ -16,8 +16,8 @@
 import torch
 from typing_extensions import override
 
-from lightning.fabric.accelerators import _AcceleratorRegistry
 from lightning.fabric.accelerators.mps import MPSAccelerator as _MPSAccelerator
+from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.utilities.device_parser import _parse_gpu_ids
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.accelerators.accelerator import Accelerator
@@ -78,11 +78,16 @@ def is_available() -> bool:
         """MPS is only available on a machine with the ARM-based Apple Silicon processors."""
         return _MPSAccelerator.is_available()
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "mps"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
-            "mps",
+            cls.name(),
             cls,
             description=cls.__name__,
         )

src/lightning/pytorch/accelerators/xla.py

Lines changed: 11 additions & 2 deletions
@@ -15,7 +15,7 @@
 
 from typing_extensions import override
 
-from lightning.fabric.accelerators import _AcceleratorRegistry
+from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.accelerators.xla import XLAAccelerator as FabricXLAAccelerator
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.accelerators.accelerator import Accelerator
@@ -49,7 +49,16 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
             "avg. peak memory (MB)": peak_memory,
         }
 
+    @staticmethod
+    @override
+    def name() -> str:
+        return "tpu"
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
-        accelerator_registry.register("tpu", cls, description=cls.__name__)
+        accelerator_registry.register(
+            cls.name(),
+            cls,
+            description=cls.__name__,
+        )
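
The `lightning.pytorch` accelerators mirror their fabric counterparts, so both layers answer with the same identifiers. A small sanity check over the four built-in classes (a sketch, assuming all four modules import cleanly):

```python
from lightning.pytorch.accelerators.cpu import CPUAccelerator
from lightning.pytorch.accelerators.cuda import CUDAAccelerator
from lightning.pytorch.accelerators.mps import MPSAccelerator
from lightning.pytorch.accelerators.xla import XLAAccelerator

# Each class now carries its own registry key:
for cls in (CPUAccelerator, CUDAAccelerator, MPSAccelerator, XLAAccelerator):
    print(f"{cls.__name__}: {cls.name()}")
# CPUAccelerator: cpu
# CUDAAccelerator: cuda
# MPSAccelerator: mps
# XLAAccelerator: tpu
```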
