Skip to content

Commit ce08acc

Browse files
committed
add name
1 parent b554e99 commit ce08acc

File tree

5 files changed: +18 additions, −25 deletions

src/lightning/fabric/accelerators/accelerator.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ def auto_device_count() -> int:
5656
def is_available() -> bool:
5757
"""Detect if the hardware is available."""
5858

59+
@staticmethod
60+
@abstractmethod
61+
def name() -> str:
62+
"""The name of the accelerator."""
63+
5964
@classmethod
6065
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
61-
pass
66+
accelerator_registry.register(cls.name(), cls, description=cls.__name__)

src/lightning/fabric/accelerators/cpu.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,10 @@ def is_available() -> bool:
6262
"""CPU is always available for execution."""
6363
return True
6464

65-
@classmethod
65+
@staticmethod
6666
@override
67-
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
68-
accelerator_registry.register(
69-
"cpu",
70-
cls,
71-
description=cls.__name__,
72-
)
67+
def name() -> str:
68+
return "cpu"
7369

7470

7571
def _parse_cpu_cores(cpu_cores: Union[int, str]) -> int:

src/lightning/fabric/accelerators/cuda.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,10 @@ def auto_device_count() -> int:
6666
def is_available() -> bool:
6767
return num_cuda_devices() > 0
6868

69-
@classmethod
69+
@staticmethod
7070
@override
71-
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
72-
accelerator_registry.register(
73-
"cuda",
74-
cls,
75-
description=cls.__name__,
76-
)
71+
def name() -> str:
72+
return "cuda"
7773

7874

7975
def find_usable_cuda_devices(num_devices: int = -1) -> list[int]:

src/lightning/fabric/accelerators/mps.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,10 @@ def is_available() -> bool:
7474
mps_disabled = os.getenv("DISABLE_MPS", "0") == "1"
7575
return not mps_disabled and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64")
7676

77-
@classmethod
77+
@staticmethod
7878
@override
79-
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
80-
accelerator_registry.register(
81-
"mps",
82-
cls,
83-
description=cls.__name__,
84-
)
79+
def name() -> str:
80+
return "mps"
8581

8682

8783
def _get_all_available_mps_gpus() -> list[int]:

src/lightning/fabric/accelerators/xla.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ def is_available() -> bool:
9393
# when `torch_xla` is imported but not used
9494
return False
9595

96-
@classmethod
96+
@staticmethod
9797
@override
98-
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
99-
accelerator_registry.register("tpu", cls, description=cls.__name__)
98+
def name() -> str:
99+
return "tpu"
100100

101101

102102
# PJRT support requires this minimum version

0 commit comments

Comments (0)