
Commit e17dd29

circlecrystal authored and carmocca committed

bug fix: restore_optimizers correctly handles non-mapping values in optimizer.state.values() (#11757)

Co-authored-by: Carlos Mocholi <[email protected]>

1 parent a971a6a, commit e17dd29
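
For context, a minimal sketch of the failure this commit fixes, using a toy optimizer whose per-parameter state is a bare tensor rather than a dict; `TensorStateOptimizer` is an illustrative name, not an optimizer that ships with PyTorch or Lightning:

```python
import torch


class TensorStateOptimizer(torch.optim.Optimizer):
    """Toy optimizer that stores a bare tensor (not a dict) per parameter."""

    def __init__(self, params):
        super().__init__(params, defaults={})

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                # Non-mapping state value: a Tensor instead of the usual dict.
                self.state[p] = torch.zeros_like(p)


model = torch.nn.Linear(2, 2)
opt = TensorStateOptimizer(model.parameters())
opt.step()

# The pre-fix traversal assumed every state value is a mapping:
try:
    for state in opt.state.values():
        for k, v in state.items():
            pass
except AttributeError as err:
    print(f"pre-fix code breaks on tensor-valued state: {err}")
```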

File tree

2 files changed (8 additions, 4 deletions):
- CHANGELOG.md
- pytorch_lightning/trainer/connectors/checkpoint_connector.py

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue to avoid validation loop run on restart ([#11552](https://github.com/PyTorchLightning/pytorch-lightning/pull/11552))
 - The Rich progress bar now correctly shows the `on_epoch` logged values on train epoch end ([#11689](https://github.com/PyTorchLightning/pytorch-lightning/pull/11689))
 - Fixed an issue to make the `step` argument in `WandbLogger.log_image` work ([#11716](https://github.com/PyTorchLightning/pytorch-lightning/pull/11716))
+- Fixed `restore_optimizers` for mapping states ([#11757](https://github.com/PyTorchLightning/pytorch-lightning/pull/11757))
 
 
 ## [1.5.9] - 2022-01-18
```

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -279,10 +279,13 @@ def restore_optimizers(self) -> None:
             # move optimizer to GPU 1 weight at a time
             # avoids OOM
             if self.trainer.root_gpu is not None:
-                for state in optimizer.state.values():
-                    for k, v in state.items():
-                        if isinstance(v, torch.Tensor):
-                            state[k] = v.cuda(self.trainer.root_gpu)
+                for param, state in optimizer.state.items():
+                    if isinstance(state, dict):
+                        for k, v in state.items():
+                            if isinstance(v, torch.Tensor):
+                                state[k] = v.cuda(self.trainer.root_gpu)
+                    elif isinstance(state, torch.Tensor):
+                        optimizer.state[param] = state.cuda(self.trainer.root_gpu)
 
     def restore_lr_schedulers(self) -> None:
         """Restores the learning rate scheduler states from the pre-loaded checkpoint."""
```
