Commit 037a24b

feat: add device_name classmethod in Accelerator.
1 parent 2460746 commit 037a24b

File tree

6 files changed: +57 -17 lines changed

src/lightning/pytorch/accelerators/accelerator.py

Lines changed: 6 additions & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC
-from typing import Any
+from typing import Any, Optional
 
 import lightning.pytorch as pl
 from lightning.fabric.accelerators.accelerator import Accelerator as _Accelerator
@@ -45,3 +45,8 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
 
         """
         raise NotImplementedError
+
+    @classmethod
+    def device_name(cls, device: Optional = None) -> str:
+        """Get the device name for a given device."""
+        return str(cls.is_available())
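
For context, a minimal usage sketch of the base-class fallback (hedged: it assumes CPUAccelerator inherits device_name unchanged, since this commit adds no override for it):

from lightning.pytorch.accelerators import CPUAccelerator

# Without an override, device_name() just stringifies availability,
# so a CPU run reports "True" rather than a hardware name.
print(CPUAccelerator.device_name())  # -> "True"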

src/lightning/pytorch/accelerators/cuda.py

Lines changed: 6 additions & 0 deletions
@@ -113,6 +113,12 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
             description=cls.__name__,
         )
 
+    @classmethod
+    def device_name(cls, device: Optional[torch.types.Device] = None) -> str:
+        if not cls.is_available():
+            return "False"
+        return torch.cuda.get_device_name(device)
+
 
 def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]: # pragma: no-cover
     """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.

src/lightning/pytorch/accelerators/mps.py

Lines changed: 7 additions & 0 deletions
@@ -87,6 +87,13 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
             description=cls.__name__,
         )
 
+    @classmethod
+    def device_name(cls, device: Optional = None) -> str:
+        # todo: implement a better way to get the device name
+        available = cls.is_available()
+        gpu_type = " (mps)" if available else ""
+        return f"{available}{gpu_type}"
+
 
 # device metrics
 _VM_PERCENT = "M1_vm_percent"
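
A minimal sketch of the MPS behavior (hedged: as the TODO above notes, there is no real name lookup yet, so the method only reports availability plus a static suffix):

from lightning.pytorch.accelerators import MPSAccelerator

# Apple Silicon: "True (mps)"; any other machine: "False".
print(MPSAccelerator.device_name())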

src/lightning/pytorch/accelerators/xla.py

Lines changed: 23 additions & 1 deletion
@@ -11,11 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, Optional
 
 from typing_extensions import override
 
 from lightning.fabric.accelerators import _AcceleratorRegistry
+from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
 from lightning.fabric.accelerators.xla import XLAAccelerator as FabricXLAAccelerator
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.accelerators.accelerator import Accelerator
@@ -53,3 +54,24 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register("tpu", cls, description=cls.__name__)
+
+    @classmethod
+    def device_name(cls, device: Optional = None) -> str:
+        is_available = cls.is_available()
+        if not is_available:
+            return str(is_available)
+
+        if _XLA_GREATER_EQUAL_2_1:
+            from torch_xla._internal import tpu
+        else:
+            from torch_xla.experimental import tpu
+        import torch_xla.core.xla_env_vars as xenv
+        from requests.exceptions import HTTPError
+
+        try:
+            ret = tpu.get_tpu_env()[xenv.ACCELERATOR_TYPE]
+        except HTTPError:
+            # Fallback to "True" if HTTPError is raised during retrieving device information
+            ret = str(is_available)
+
+        return ret
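
For reference, a hedged sketch of the XLA path (it assumes a Cloud TPU environment; the accelerator-type string, e.g. "v4-8", comes from the TPU metadata lookup and may vary):

from lightning.pytorch.accelerators import XLAAccelerator

# On a TPU host this resolves ACCELERATOR_TYPE from the TPU environment
# (e.g. "v4-8"); on an HTTPError it falls back to "True", and without
# TPUs it returns "False".
print(XLAAccelerator.device_name())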

src/lightning/pytorch/trainer/setup.py

Lines changed: 7 additions & 12 deletions
@@ -142,21 +142,16 @@ def _init_profiler(trainer: "pl.Trainer", profiler: Optional[Union[Profiler, str
 
 
 def _log_device_info(trainer: "pl.Trainer") -> None:
-    if CUDAAccelerator.is_available():
-        gpu_available = True
-        gpu_type = " (cuda)"
-    elif MPSAccelerator.is_available():
-        gpu_available = True
-        gpu_type = " (mps)"
+    if isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator)):
+        gpu_used = trainer.num_devices
+        device_names = list({trainer.accelerator.device_name(d) for d in trainer.devices})
     else:
-        gpu_available = False
-        gpu_type = ""
-
-    gpu_used = isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator))
-    rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")
+        gpu_used = 0
+        device_names = "False"
+    rank_zero_info(f"GPU available: {device_names}, using: {gpu_used} {'devices' if gpu_used else 'device'}.")
 
     num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, XLAAccelerator) else 0
-    rank_zero_info(f"TPU available: {XLAAccelerator.is_available()}, using: {num_tpu_cores} TPU cores")
+    rank_zero_info(f"TPU available: {XLAAccelerator.device_name()}, using: {num_tpu_cores} TPU cores")
 
     if _habana_available_and_importable():
         from lightning_habana import HPUAccelerator
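
The net effect on startup logging, sketched with illustrative values (a two-GPU CUDA machine without TPUs is assumed; actual names depend on the hardware):

GPU available: ['NVIDIA A100-SXM4-40GB'], using: 2 devices.
TPU available: False, using: 0 TPU cores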

src/lightning/pytorch/trainer/trainer.py

Lines changed: 8 additions & 3 deletions
@@ -1187,16 +1187,21 @@ def num_nodes(self) -> int:
         return getattr(self.strategy, "num_nodes", 1)
 
     @property
-    def device_ids(self) -> list[int]:
-        """List of device indexes per node."""
+    def devices(self) -> list[torch.device]:
+        """The devices the trainer uses per node."""
         devices = (
             self.strategy.parallel_devices
             if isinstance(self.strategy, ParallelStrategy)
             else [self.strategy.root_device]
         )
         assert devices is not None
+        return devices
+
+    @property
+    def device_ids(self) -> list[int]:
+        """List of device indexes per node."""
         device_ids = []
-        for idx, device in enumerate(devices):
+        for idx, device in enumerate(self.devices):
             if isinstance(device, torch.device):
                 device_ids.append(device.index or idx)
             elif isinstance(device, int):