File tree Expand file tree Collapse file tree 5 files changed +18
-25
lines changed
src/lightning/fabric/accelerators Expand file tree Collapse file tree 5 files changed +18
-25
lines changed Original file line number Diff line number Diff line change @@ -56,6 +56,11 @@ def auto_device_count() -> int:
5656 def is_available () -> bool :
5757 """Detect if the hardware is available."""
5858
59+ @staticmethod
60+ @abstractmethod
61+ def name() -> str:
62+ """The name of the accelerator."""
63+
5964 @classmethod
6065 def register_accelerators (cls , accelerator_registry : _AcceleratorRegistry ) -> None :
61- pass
66+ accelerator_registry.register(cls.name(), cls, description=cls.__name__)
Original file line number Diff line number Diff line change @@ -62,14 +62,10 @@ def is_available() -> bool:
6262 """CPU is always available for execution."""
6363 return True
6464
65- @classmethod
65+ @staticmethod
6666 @override
67- def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
68- accelerator_registry .register (
69- "cpu" ,
70- cls ,
71- description = cls .__name__ ,
72- )
67+ def name() -> str:
68+ return "cpu"
7369
7470
7571def _parse_cpu_cores (cpu_cores : Union [int , str ]) -> int :
Original file line number Diff line number Diff line change @@ -66,14 +66,10 @@ def auto_device_count() -> int:
6666 def is_available () -> bool :
6767 return num_cuda_devices () > 0
6868
69- @classmethod
69+ @staticmethod
7070 @override
71- def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
72- accelerator_registry .register (
73- "cuda" ,
74- cls ,
75- description = cls .__name__ ,
76- )
71+ def name() -> str:
72+ return "cuda"
7773
7874
7975def find_usable_cuda_devices (num_devices : int = - 1 ) -> list [int ]:
Original file line number Diff line number Diff line change @@ -74,14 +74,10 @@ def is_available() -> bool:
7474 mps_disabled = os .getenv ("DISABLE_MPS" , "0" ) == "1"
7575 return not mps_disabled and torch .backends .mps .is_available () and platform .processor () in ("arm" , "arm64" )
7676
77- @classmethod
77+ @staticmethod
7878 @override
79- def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
80- accelerator_registry .register (
81- "mps" ,
82- cls ,
83- description = cls .__name__ ,
84- )
79+ def name() -> str:
80+ return "mps"
8581
8682
8783def _get_all_available_mps_gpus () -> list [int ]:
Original file line number Diff line number Diff line change @@ -93,10 +93,10 @@ def is_available() -> bool:
9393 # when `torch_xla` is imported but not used
9494 return False
9595
96- @classmethod
96+ @staticmethod
9797 @override
98- def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
99- accelerator_registry.register("tpu", cls, description=cls.__name__)
98+ def name() -> str:
99+ return "tpu"
100100
101101
102102# PJRT support requires this minimum version
You can’t perform that action at this time.
0 commit comments