3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -283,6 +283,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed exporting `__version__` in `__init__` ([#19221](https://github.com/Lightning-AI/pytorch-lightning/pull/19221))
 
 
+- Fixed an issue preventing users from loading a GPU-trained model with `model.load_from_checkpoint()` on a CPU-only machine with a CPU-only PyTorch installation ([#19024](https://github.com/Lightning-AI/lightning/pull/19024))
+
+
 ## [2.1.3] - 2023-12-21
 
 ### Changed
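For context, the user-facing scenario this entry covers looks roughly like the sketch below; `LitModel` and the checkpoint path are placeholders, and the snippet assumes a machine whose PyTorch build has no CUDA support.

```python
import torch
from lit_model import LitModel  # hypothetical LightningModule subclass

assert not torch.cuda.is_available()  # CPU-only machine and CPU-only PyTorch build

# The checkpoint was produced by GPU training elsewhere; with this fix it can
# be restored here and the returned model stays on CPU.
model = LitModel.load_from_checkpoint("gpu_trained.ckpt")
model.eval()
```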
2 changes: 2 additions & 0 deletions src/lightning/pytorch/core/saving.py
@@ -96,6 +96,8 @@ def _load_from_checkpoint(

         device = next((t for t in state_dict.values() if isinstance(t, torch.Tensor)), torch.tensor(0)).device
         assert isinstance(model, pl.LightningModule)
+        if device.type == "cpu" and model.device.type == "cpu":
+            return model
         return model.to(device)
 
     raise NotImplementedError(f"Unsupported {cls}")
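As a standalone illustration of the logic above (a rough sketch, not the actual Lightning code path; `infer_checkpoint_device` and `maybe_move` are made-up helper names):

```python
import torch
from torch import nn


def infer_checkpoint_device(state_dict: dict) -> torch.device:
    # Peek at the first tensor in the state dict; fall back to a dummy CPU
    # tensor when the state dict contains no tensors at all.
    first = next((t for t in state_dict.values() if isinstance(t, torch.Tensor)), torch.tensor(0))
    return first.device


def maybe_move(model: nn.Module, state_dict: dict) -> nn.Module:
    device = infer_checkpoint_device(state_dict)
    model_device = next(model.parameters(), torch.tensor(0.0)).device
    # When both the checkpoint tensors and the model already live on CPU there
    # is nothing to move, so the redundant .to() call is skipped entirely.
    if device.type == "cpu" and model_device.type == "cpu":
        return model
    return model.to(device)
```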
3 changes: 2 additions & 1 deletion src/lightning/pytorch/strategies/single_device.py
@@ -76,7 +76,8 @@ def root_device(self) -> torch.device:
     @override
     def model_to_device(self) -> None:
         assert self.model is not None, "self.model must be set before self.model.to()"
-        self.model.to(self.root_device)
+        if self.model.device.type != self.root_device.type:
+            self.model.to(self.root_device)
 
     @property
     @override
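One detail worth noting is that the new guard compares `device.type` rather than the full `torch.device` objects; comparing whole devices would treat an unindexed `cuda` and an explicit `cuda:0` as different devices, as the snippet below illustrates (general PyTorch behaviour, not code from this PR):

```python
import torch

a = torch.device("cuda")     # index left unspecified
b = torch.device("cuda", 0)  # explicit index 0

print(a == b)            # False: the indices (None vs 0) differ
print(a.type == b.type)  # True: both devices are of type "cuda"
```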
3 changes: 2 additions & 1 deletion src/lightning/pytorch/strategies/strategy.py
@@ -533,7 +533,8 @@ def teardown(self) -> None:

         if self.lightning_module is not None:
             log.debug(f"{self.__class__.__name__}: moving model to CPU")
-            self.lightning_module.cpu()
+            if self.lightning_module.device.type != "cpu":
+                self.lightning_module.cpu()
         self.precision_plugin.teardown()
         assert self.accelerator is not None
         self.accelerator.teardown()
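The teardown change follows the same guard pattern; in isolation it amounts to something like the following (a sketch with a hypothetical `move_to_cpu_if_needed` helper, not the Strategy code itself):

```python
import torch
from torch import nn


def move_to_cpu_if_needed(module: nn.Module) -> None:
    # Only call .cpu() when the module's parameters are not already on the CPU,
    # which avoids a pointless transfer at the end of every run.
    param = next(module.parameters(), None)
    if param is not None and param.device.type != "cpu":
        module.cpu()


module = nn.Linear(4, 2)       # nn.Linear is created on CPU by default
move_to_cpu_if_needed(module)  # no-op: the module already lives on CPU
```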