Skip to content

Commit e4c2ac5

Browse files
committed
fix: set _DeviceDtypeModuleMixin _device from torch's default device function.
1 parent 663b6ce commit e4c2ac5

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

src/lightning/fabric/utilities/device_dtype_mixin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class _DeviceDtypeModuleMixin(Module):
2525
def __init__(self) -> None:
2626
super().__init__()
2727
self._dtype: Union[str, torch.dtype] = torch.get_default_dtype()
28-
self._device = torch.device("cpu")
28+
self._device: torch.device = torch.get_default_device()
2929

3030
@property
3131
def dtype(self) -> Union[str, torch.dtype]:

tests/tests_pytorch/trainer/test_trainer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2107,6 +2107,22 @@ def test_init_module_context(monkeypatch):
21072107
strategy.tensor_init_context.reset_mock()
21082108

21092109

2110+
@pytest.mark.parametrize(
2111+
("target_device", "accelerator", "devices"),
2112+
[
2113+
("cpu", "cpu", "auto"),
2114+
pytest.param("cuda:0", "gpu", [0], marks=RunIf(min_cuda_gpus=1)),
2115+
pytest.param("cuda:1", "gpu", [1], marks=RunIf(min_cuda_gpus=2)),
2116+
],
2117+
)
2118+
def test_init_module_device_type(target_device, accelerator, devices):
2119+
"""Test that the strategy returns the context manager for initializing the module."""
2120+
trainer = Trainer(accelerator=accelerator, devices=devices)
2121+
with trainer.init_module():
2122+
model = BoringModel()
2123+
assert model.device == torch.device(target_device)
2124+
2125+
21102126
def test_expand_home_trainer():
21112127
"""Test that the dirpath gets expanded if it contains `~`."""
21122128
home_root = Path.home()

0 commit comments

Comments
 (0)