3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -26,6 +26,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
[#21057](https://github.com/Lightning-AI/pytorch-lightning/pull/21057), [#21093](https://github.com/Lightning-AI/pytorch-lightning/pull/21093))


- Set `_DeviceDtypeModuleMixin._device` from torch's default device function ([#21164](https://github.com/Lightning-AI/pytorch-lightning/pull/21164))


### Fixed

- Fixed a missing device id for PyTorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))
4 changes: 3 additions & 1 deletion src/lightning/fabric/utilities/device_dtype_mixin.py
@@ -25,7 +25,9 @@ class _DeviceDtypeModuleMixin(Module):
def __init__(self) -> None:
super().__init__()
self._dtype: Union[str, torch.dtype] = torch.get_default_dtype()
self._device = torch.device("cpu")
# Workarounds from the original pytorch issue:
# https://github.com/pytorch/pytorch/issues/115333#issuecomment-1848449687
self._device = torch.get_default_device() if torch.__version__ >= "2.3.0" else torch.empty(0).device

@property
def dtype(self) -> Union[str, torch.dtype]:
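For context, here is a minimal sketch of the default-device resolution the change above relies on. The helper name `resolve_default_device` is hypothetical, and it uses a `hasattr` check rather than the diff's version-string comparison, only to keep the sketch free of version parsing:

```python
import torch


def resolve_default_device() -> torch.device:
    # torch.get_default_device() was added in torch 2.3; older versions can
    # instead inspect the device of a freshly allocated tensor, which also
    # honors torch.set_default_device(...), as suggested in the linked issue.
    if hasattr(torch, "get_default_device"):
        return torch.get_default_device()
    return torch.empty(0).device


torch.set_default_device("meta")  # "meta" is used only so the example needs no GPU
print(resolve_default_device())   # meta
torch.set_default_device("cpu")
print(resolve_default_device())   # cpu
```

With this resolution in place, a module created under a device context or after `torch.set_default_device(...)` reports the matching `.device`, which is what the new test below exercises.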
21 changes: 21 additions & 0 deletions tests/tests_fabric/utilities/test_device_dtype_mixin.py
@@ -50,6 +50,27 @@ def test_submodules_device_and_dtype(dst_device_str, dst_type):
assert model.dtype == model.module.module.dtype == dst_type


@pytest.mark.parametrize(
"dst_device_str",
[
"cpu",
pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
pytest.param("mps:0", marks=RunIf(mps=True)),
],
)
@pytest.mark.parametrize("dst_type", [torch.half, torch.float, torch.double])
def test_submodules_context_device_and_dtype(dst_device_str, dst_type):
if dst_device_str == "mps:0" and dst_type in (torch.half, torch.double):
pytest.skip("MPS does not yet support half and double.")

dst_device = torch.device(dst_device_str)
torch.set_default_dtype(dst_type)
with dst_device:
model = TopModule()
assert model.device == dst_device
assert model.dtype == dst_type


@pytest.mark.parametrize(
"device",
[
16 changes: 16 additions & 0 deletions tests/tests_pytorch/trainer/test_trainer.py
@@ -2107,6 +2107,22 @@ def test_init_module_context(monkeypatch):
strategy.tensor_init_context.reset_mock()


@pytest.mark.parametrize(
("target_device", "accelerator", "devices"),
[
("cpu", "cpu", "auto"),
pytest.param("cuda:0", "gpu", [0], marks=RunIf(min_cuda_gpus=1)),
pytest.param("cuda:1", "gpu", [1], marks=RunIf(min_cuda_gpus=2)),
],
)
def test_init_module_device_type(target_device, accelerator, devices):
"""Test that the strategy returns the context manager for initializing the module."""
trainer = Trainer(accelerator=accelerator, devices=devices)
with trainer.init_module():
model = BoringModel()
assert model.device == torch.device(target_device)


def test_expand_home_trainer():
"""Test that the dirpath gets expanded if it contains `~`."""
home_root = Path.home()
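A usage sketch of the behavior covered by `test_init_module_device_type`; `BoringModel` is imported here from the public demos module rather than the test helpers, and `accelerator="cpu"` is chosen so the example runs without a GPU:

```python
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

trainer = Trainer(accelerator="cpu", devices=1)
with trainer.init_module():
    # Parameters are allocated directly on the trainer's root device, and the
    # device/dtype mixin now reports that device instead of assuming "cpu".
    model = BoringModel()
assert model.device == torch.device("cpu")
```

With `accelerator="gpu", devices=[0]` the same assertion would hold for `torch.device("cuda:0")`, as the parametrized cases in the test above check.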