Commits (24)
037a24b  feat: add `device_name` classmethod in `Accelerator`. (GdoongMathew, Aug 24, 2025)
c1746b2  feat: change to original device logic in setup. (GdoongMathew, Aug 24, 2025)
a383d79  revert: revert changes in trainer. (GdoongMathew, Aug 24, 2025)
7db8793  fix: fix type annotation. (GdoongMathew, Aug 24, 2025)
3a3949a  fix: fix type annotation. (GdoongMathew, Aug 24, 2025)
8cb809f  add `override` decorator. (GdoongMathew, Aug 26, 2025)
944ad69  fix tests. (GdoongMathew, Aug 26, 2025)
1d9bd0d  fix tests. (GdoongMathew, Aug 26, 2025)
92b1d69  fix mypy and device string format. (GdoongMathew, Aug 26, 2025)
0a5725b  fix tests. (GdoongMathew, Aug 26, 2025)
5b11bdb  mps override decorator. (GdoongMathew, Aug 27, 2025)
4abcd13  Merge branch 'master' into feat/device_name (Borda, Aug 27, 2025)
63a0f70  empty str (Borda, Aug 27, 2025)
8e91f8f  return empty string if accelerator is not available. (GdoongMathew, Aug 27, 2025)
f124d55  fix: fix unittests. (GdoongMathew, Aug 27, 2025)
aa16731  Merge branch 'master' into feat/device_name (GdoongMathew, Aug 28, 2025)
eddc009  Merge branch 'master' into feat/device_name (GdoongMathew, Aug 29, 2025)
0cfa761  Merge branch 'master' into feat/device_name (GdoongMathew, Sep 5, 2025)
724a01c  Merge branch 'master' into feat/device_name (GdoongMathew, Sep 8, 2025)
fe6121f  Merge branch 'master' into feat/device_name (GdoongMathew, Sep 15, 2025)
388d786  Merge branch 'master' into feat/device_name (GdoongMathew, Sep 16, 2025)
9ec946f  Merge branch 'master' into feat/device_name (Borda, Sep 25, 2025)
eb53f4c  Merge branch 'master' into feat/device_name (GdoongMathew, Sep 30, 2025)
87a021d  Merge branch 'master' into feat/device_name (GdoongMathew, Oct 6, 2025)
7 changes: 6 additions & 1 deletion src/lightning/pytorch/accelerators/accelerator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from typing import Any
from typing import Any, Optional

import lightning.pytorch as pl
from lightning.fabric.accelerators.accelerator import Accelerator as _Accelerator
@@ -45,3 +45,8 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:

"""
raise NotImplementedError

@classmethod
def device_name(cls, device: Optional[_DEVICE] = None) -> str:
"""Get the device name for a given device."""
return str(cls.is_available())
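A minimal usage sketch of the new classmethod, assuming this branch is installed. The base implementation only stringifies `is_available()`, so accelerators that do not override it (e.g. CPU) report `"True"`/`"False"` rather than a hardware name:

```python
from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator

# Base-class fallback: the stringified availability flag.
print(CPUAccelerator.device_name())   # "True" on any machine

# Overridden for CUDA in the next file: a descriptive name, or "" when unavailable.
print(CUDAAccelerator.device_name())  # e.g. "NVIDIA A100-SXM4-40GB" or ""
```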
7 changes: 7 additions & 0 deletions src/lightning/pytorch/accelerators/cuda.py
@@ -113,6 +113,13 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
description=cls.__name__,
)

@classmethod
@override
def device_name(cls, device: Optional[_DEVICE] = None) -> str:
if not cls.is_available():
return ""
return torch.cuda.get_device_name(device)


def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]: # pragma: no-cover
"""Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
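For context, a hedged sketch of how the CUDA override behaves: it forwards to `torch.cuda.get_device_name` (which accepts an index, a `torch.device`, or `None` for the current device) and signals "no device" with an empty string:

```python
import torch

from lightning.pytorch.accelerators import CUDAAccelerator

if torch.cuda.is_available():
    # Same name that torch reports for device 0, e.g. "NVIDIA GeForce RTX 3090".
    assert CUDAAccelerator.device_name(0) == torch.cuda.get_device_name(0)
else:
    # Empty string signals that no CUDA device is available.
    assert CUDAAccelerator.device_name() == ""
```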
8 changes: 8 additions & 0 deletions src/lightning/pytorch/accelerators/mps.py
@@ -87,6 +87,14 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
description=cls.__name__,
)

@classmethod
@override
def device_name(cls, device: Optional[_DEVICE] = None) -> str:
# todo: implement a better way to get the device name
if not cls.is_available():
return ""
return "True (mps)"


# device metrics
_VM_PERCENT = "M1_vm_percent"
25 changes: 24 additions & 1 deletion src/lightning/pytorch/accelerators/xla.py
@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, Optional

from typing_extensions import override

from lightning.fabric.accelerators import _AcceleratorRegistry
from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
from lightning.fabric.accelerators.xla import XLAAccelerator as FabricXLAAccelerator
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.accelerators.accelerator import Accelerator
@@ -53,3 +54,25 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
@override
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
accelerator_registry.register("tpu", cls, description=cls.__name__)

@classmethod
@override
def device_name(cls, device: Optional[_DEVICE] = None) -> str:
is_available = cls.is_available()
if not is_available:
return ""

if _XLA_GREATER_EQUAL_2_1:
from torch_xla._internal import tpu
else:
from torch_xla.experimental import tpu
import torch_xla.core.xla_env_vars as xenv
from requests.exceptions import HTTPError

try:
ret = tpu.get_tpu_env()[xenv.ACCELERATOR_TYPE]
except HTTPError:
# Fall back to "True" if an HTTPError is raised while retrieving device information
ret = str(is_available)

return ret
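Roughly what the XLA path yields, as a sketch: the accelerator type string from the TPU environment (for example a value like "v4-8"), `"True"` if the metadata request fails with an `HTTPError`, and `""` when no XLA device is available:

```python
from lightning.pytorch.accelerators import XLAAccelerator

name = XLAAccelerator.device_name()
if not name:
    print("No XLA device available")
else:
    # Either the TPU accelerator type, or the "True" fallback after an HTTPError.
    print(f"Running on: {name}")
```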
21 changes: 12 additions & 9 deletions src/lightning/pytorch/trainer/setup.py
@@ -143,20 +143,23 @@ def _init_profiler(trainer: "pl.Trainer", profiler: Optional[Union[Profiler, str

def _log_device_info(trainer: "pl.Trainer") -> None:
if CUDAAccelerator.is_available():
gpu_available = True
gpu_type = " (cuda)"
if isinstance(trainer.accelerator, CUDAAccelerator):
device_name = ", ".join(list({CUDAAccelerator.device_name(d) for d in trainer.device_ids}))
else:
device_name = CUDAAccelerator.device_name()
elif MPSAccelerator.is_available():
gpu_available = True
gpu_type = " (mps)"
device_name = MPSAccelerator.device_name()
else:
gpu_available = False
gpu_type = ""
device_name = str(False)

gpu_used = isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator))
rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")
gpu_used = trainer.num_devices if isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator)) else 0
rank_zero_info(f"GPU available: {device_name}, using: {gpu_used} devices.")

num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, XLAAccelerator) else 0
rank_zero_info(f"TPU available: {XLAAccelerator.is_available()}, using: {num_tpu_cores} TPU cores")
rank_zero_info(
f"TPU available: {XLAAccelerator.device_name() if XLAAccelerator.is_available() else str(False)}, "
f"using: {num_tpu_cores} TPU cores"
)

if _habana_available_and_importable():
from lightning_habana import HPUAccelerator
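The practical effect on the startup banner, sketched with an illustrative device name:

```text
Before:  GPU available: True (cuda), used: True
After:   GPU available: NVIDIA GeForce RTX 3090, using: 1 devices.
```

The TPU line keeps its `TPU available: ..., using: N TPU cores` form, but now shows the accelerator type instead of `True` when a TPU is detected.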
13 changes: 13 additions & 0 deletions tests/tests_pytorch/accelerators/test_gpu.py
@@ -68,3 +68,16 @@ def test_gpu_availability():
def test_warning_if_gpus_not_used(cuda_count_1):
with pytest.warns(UserWarning, match="GPU available but not used"):
Trainer(accelerator="cpu")


@RunIf(min_cuda_gpus=1)
def test_gpu_device_name():
for i in range(torch.cuda.device_count()):
assert torch.cuda.get_device_name(i) == CUDAAccelerator.device_name(i)

with torch.device("cuda:0"):
assert torch.cuda.get_device_name(0) == CUDAAccelerator.device_name()


def test_gpu_device_name_no_gpu(cuda_count_0):
assert CUDAAccelerator.device_name() == ""
11 changes: 11 additions & 0 deletions tests/tests_pytorch/accelerators/test_mps.py
@@ -13,6 +13,7 @@
# limitations under the License.

from collections import namedtuple
from unittest import mock

import pytest
import torch
@@ -39,6 +40,16 @@ def test_mps_availability():
assert MPSAccelerator.is_available()


@RunIf(mps=True)
def test_mps_device_name():
assert MPSAccelerator.device_name() == "True (mps)"


def test_mps_device_name_not_available():
with mock.patch("torch.backends.mps.is_available", return_value=False):
assert MPSAccelerator.device_name() == ""


def test_warning_if_mps_not_used(mps_count_1):
with pytest.warns(UserWarning, match="GPU available but not used"):
Trainer(accelerator="cpu")
13 changes: 13 additions & 0 deletions tests/tests_pytorch/accelerators/test_xla.py
@@ -302,6 +302,19 @@ def test_warning_if_tpus_not_used(tpu_available):
Trainer(accelerator="cpu")


@RunIf(tpu=True)
def test_tpu_device_name():
from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1

if _XLA_GREATER_EQUAL_2_1:
from torch_xla._internal import tpu
else:
from torch_xla.experimental import tpu
import torch_xla.core.xla_env_vars as xenv

assert XLAAccelerator.device_name() == tpu.get_tpu_env()[xenv.ACCELERATOR_TYPE]


@pytest.mark.parametrize(
("devices", "expected_device_ids"),
[
6 changes: 6 additions & 0 deletions tests/tests_pytorch/conftest.py
@@ -182,6 +182,7 @@ def thread_police_duuu_daaa_duuu_daaa():
def mock_cuda_count(monkeypatch, n: int) -> None:
monkeypatch.setattr(lightning.fabric.accelerators.cuda, "num_cuda_devices", lambda: n)
monkeypatch.setattr(lightning.pytorch.accelerators.cuda, "num_cuda_devices", lambda: n)
monkeypatch.setattr(torch.cuda, "get_device_name", lambda _: "Mocked CUDA Device")


@pytest.fixture
@@ -244,6 +245,11 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N
monkeypatch.setitem(sys.modules, "torch_xla", Mock())
monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
monkeypatch.setattr(
lightning.pytorch.accelerators.xla.XLAAccelerator,
"device_name",
lambda *_: "Mocked TPU Device",
)


@pytest.fixture
2 changes: 1 addition & 1 deletion tests/tests_pytorch/plugins/test_cluster_integration.py
@@ -66,7 +66,7 @@ def test_ranks_available_manual_strategy_selection(_, strategy_cls):
"""Test that the rank information is readily available after Trainer initialization."""
num_nodes = 2
for cluster, variables, expected in environment_combinations():
with mock.patch.dict(os.environ, variables):
with mock.patch.dict(os.environ, variables), mock.patch("torch.cuda.get_device_name", return_value="GPU"):
strategy = strategy_cls(
parallel_devices=[torch.device("cuda", 1), torch.device("cuda", 2)], cluster_environment=cluster
)