Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/fabric/utilities/device_dtype_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class _DeviceDtypeModuleMixin(Module):
def __init__(self) -> None:
super().__init__()
self._dtype: Union[str, torch.dtype] = torch.get_default_dtype()
self._device = torch.device("cpu")
self._device: torch.device = torch.get_default_device()

@property
def dtype(self) -> Union[str, torch.dtype]:
Expand Down
16 changes: 16 additions & 0 deletions tests/tests_pytorch/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2107,6 +2107,22 @@ def test_init_module_context(monkeypatch):
strategy.tensor_init_context.reset_mock()


@pytest.mark.parametrize(
("target_device", "accelerator", "devices"),
[
("cpu", "cpu", "auto"),
pytest.param("cuda:0", "gpu", [0], marks=RunIf(min_cuda_gpus=1)),
pytest.param("cuda:1", "gpu", [1], marks=RunIf(min_cuda_gpus=2)),
],
)
def test_init_module_device_type(target_device, accelerator, devices):
"""Test that the strategy returns the context manager for initializing the module."""
trainer = Trainer(accelerator=accelerator, devices=devices)
with trainer.init_module():
model = BoringModel()
assert model.device == torch.device(target_device)


def test_expand_home_trainer():
"""Test that the dirpath gets expanded if it contains `~`."""
home_root = Path.home()
Expand Down
Loading