Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue preventing the user to `Trainer.save_checkpoint()` an FSDP model when `Trainer.test/validate/predict()` ran after `Trainer.fit()` ([#18992](https://github.com/Lightning-AI/lightning/issues/18992))


- Fixed an issue preventing the user to `model.load_from_checkpoint()` a GPU-trained model on a CPU-only machine with a CPU-only PyTorch installation.


## [2.1.0] - 2023-10-11

### Added
Expand Down
2 changes: 2 additions & 0 deletions src/lightning/pytorch/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def _load_from_checkpoint(

device = next((t for t in state_dict.values() if isinstance(t, torch.Tensor)), torch.tensor(0)).device
assert isinstance(model, pl.LightningModule)
if device.type == "cpu" and model.device.type == "cpu":
return model
return model.to(device)

raise NotImplementedError(f"Unsupported {cls}")
Expand Down
3 changes: 2 additions & 1 deletion src/lightning/pytorch/strategies/single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def root_device(self) -> torch.device:
@override
def model_to_device(self) -> None:
assert self.model is not None, "self.model must be set before self.model.to()"
self.model.to(self.root_device)
if self.model.device.type != self.root_device.type:
self.model.to(self.root_device)

@override
def setup(self, trainer: pl.Trainer) -> None:
Expand Down
3 changes: 2 additions & 1 deletion src/lightning/pytorch/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,8 @@ def teardown(self) -> None:

if self.lightning_module is not None:
log.debug(f"{self.__class__.__name__}: moving model to CPU")
self.lightning_module.cpu()
if self.lightning_module.device.type != "cpu":
self.lightning_module.cpu()
self.precision_plugin.teardown()
assert self.accelerator is not None
self.accelerator.teardown()
Expand Down