
Commit a64e1df

rohitgr7 authored and carmocca committed
Fix fit loop restart logic to enable resume using the checkpoint (#12821)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 55f5e2d commit a64e1df
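
For context, a minimal sketch of the resume workflow this commit targets, assuming the PyTorch Lightning 1.6.x API; the model, dataloader, and checkpoint handling below are illustrative placeholders, not part of the commit:

# Minimal resume sketch (PyTorch Lightning 1.6.x). TinyModel and make_loader
# are illustrative placeholders, not part of this commit.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def make_loader():
    return DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=4)


# First run: train for one epoch; the default ModelCheckpoint saves a
# checkpoint at the end of the epoch.
trainer = pl.Trainer(max_epochs=1, enable_progress_bar=False)
trainer.fit(TinyModel(), train_dataloaders=make_loader())
ckpt_path = trainer.checkpoint_callback.best_model_path

# Second run: resume from that checkpoint with a larger epoch budget. With
# this fix, the fit loop picks up at epoch 1 instead of replaying the
# already-finished epoch 0.
trainer = pl.Trainer(max_epochs=2, enable_progress_bar=False)
trainer.fit(TinyModel(), train_dataloaders=make_loader(), ckpt_path=ckpt_path)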

File tree

4 files changed: +16 -25 lines changed


CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/PyTorchLightning/pytorch-lightning/pull/12653))
 - Enable mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/PyTorchLightning/pytorch-lightning/pull/12965))
 - Fixed `TQDMProgressBar` reset and update to show correct time estimation ([#12889](https://github.com/PyTorchLightning/pytorch-lightning/pull/12889))
+- Fixed fit loop restart logic to enable resume using the checkpoint ([#12821](https://github.com/PyTorchLightning/pytorch-lightning/pull/12821))


 ## [1.6.2] - 2022-04-27

pytorch_lightning/loops/fit_loop.py

Lines changed: 11 additions & 9 deletions

@@ -123,15 +123,10 @@ def running_loss(self) -> TensorRunningAccum:
 
     @Loop.restarting.setter
     def restarting(self, restarting: bool) -> None:
-        # if the last epoch completely finished, we are not actually restarting, we can check this to see if all
-        # current values are equal
-        values = (
-            self.epoch_progress.current.ready,
-            self.epoch_progress.current.started,
-            self.epoch_progress.current.processed,
-        )
-        finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values)
-        restarting &= finished_before_on_train_end
+        # if the last epoch completely finished, we are not actually restarting
+        values = self.epoch_progress.current.ready, self.epoch_progress.current.started
+        epoch_unfinished = any(v != self.epoch_progress.current.processed for v in values)
+        restarting = restarting and epoch_unfinished or self._iteration_based_training()
         Loop.restarting.fset(self, restarting)  # call the parent setter
 
     @property
@@ -205,6 +200,10 @@ def reset(self) -> None:
 
     def on_run_start(self) -> None:  # type: ignore[override]
         """Calls the ``on_train_start`` hook."""
+        # update the current_epoch in-case of checkpoint reload
+        if not self._iteration_based_training():
+            self.epoch_progress.current.completed = self.epoch_progress.current.processed
+
         # reset train dataloader and val dataloader
         self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
 
@@ -336,6 +335,9 @@ def _should_accumulate(self) -> bool:
         """Whether the gradients should be accumulated."""
         return self.epoch_loop._should_accumulate()
 
+    def _iteration_based_training(self) -> bool:
+        return self.trainer.max_steps != -1
+
 
 def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
     training_step_fx = getattr(trainer.lightning_module, "training_step")
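
The combined boolean in the new setter relies on Python operator precedence: `and` binds tighter than `or`. A hedged, standalone restatement of that decision, using illustrative names rather than the actual Loop API:

# Standalone restatement of the restarting decision above; EpochProgress and
# should_restart are illustrative stand-ins, not the real Lightning classes.
from dataclasses import dataclass


@dataclass
class EpochProgress:
    ready: int = 0
    started: int = 0
    processed: int = 0


def should_restart(restarting: bool, current: EpochProgress, iteration_based: bool) -> bool:
    # the epoch is unfinished if `ready` or `started` ran ahead of `processed`,
    # i.e. the previous run stopped somewhere inside the epoch
    epoch_unfinished = current.ready != current.processed or current.started != current.processed
    # `and` binds tighter than `or`: iteration-based runs (max_steps != -1)
    # are treated as restarting even when the last epoch finished cleanly
    return (restarting and epoch_unfinished) or iteration_based


# last epoch finished cleanly, epoch-based training -> not actually restarting
assert not should_restart(True, EpochProgress(ready=1, started=1, processed=1), iteration_based=False)
# interrupted mid-epoch -> restart and resume from the checkpoint
assert should_restart(True, EpochProgress(ready=2, started=2, processed=1), iteration_based=False)
# iteration-based training (max_steps != -1) -> always restart
assert should_restart(True, EpochProgress(ready=1, started=1, processed=1), iteration_based=True)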

tests/models/test_hooks.py

Lines changed: 3 additions & 15 deletions

@@ -616,8 +616,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
         "state_dict": ANY,
         "loops": ANY,
     }
-    saved_ckpt1 = {**loaded_ckpt, "global_step": 2, "epoch": 0}
-    saved_ckpt2 = {**loaded_ckpt, "global_step": 4, "epoch": 1}
+    saved_ckpt = {**loaded_ckpt, "global_step": 4, "epoch": 1}
     expected = [
         dict(name="Callback.on_init_start", args=(trainer,)),
         dict(name="Callback.on_init_end", args=(trainer,)),
@@ -647,23 +646,12 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
         dict(name="on_epoch_start"),
         dict(name="Callback.on_train_epoch_start", args=(trainer, model)),
         dict(name="on_train_epoch_start"),
-        dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
-        dict(name="Callback.state_dict"),
-        dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt1)),
-        dict(name="on_save_checkpoint", args=(saved_ckpt1,)),
-        dict(name="on_train_epoch_end"),
-        dict(name="Callback.on_epoch_end", args=(trainer, model)),
-        dict(name="on_epoch_end"),
-        dict(name="Callback.on_epoch_start", args=(trainer, model)),
-        dict(name="on_epoch_start"),
-        dict(name="Callback.on_train_epoch_start", args=(trainer, model)),
-        dict(name="on_train_epoch_start"),
         *model._train_batch(trainer, model, 2, current_epoch=1, current_batch=0),
         dict(name="training_epoch_end", args=([dict(loss=ANY)] * 2,)),
         dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
         dict(name="Callback.state_dict"),
-        dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt2)),
-        dict(name="on_save_checkpoint", args=(saved_ckpt2,)),
+        dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
+        dict(name="on_save_checkpoint", args=(saved_ckpt,)),
         dict(name="on_train_epoch_end"),
         dict(name="Callback.on_epoch_end", args=(trainer, model)),
         dict(name="on_epoch_end"),

tests/models/test_restore.py

Lines changed: 1 addition & 1 deletion

@@ -199,7 +199,7 @@ def on_train_start(self):
         if self.trainer.state.fn == TrainerFn.TUNING:
             self._test_on_val_test_predict_tune_start()
         else:
-            assert self.trainer.current_epoch == state_dict["epoch"]
+            assert self.trainer.current_epoch == state_dict["epoch"] + 1
             assert self.trainer.global_step == state_dict["global_step"]
             assert self._check_model_state_dict()
             assert self._check_optimizers()
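
The updated assertion encodes the resume convention this commit establishes: a checkpoint written at the end of epoch N stores `"epoch": N`, and a trainer resumed from it reports `current_epoch == N + 1`, since epoch N has already been processed. A hedged check outside the test, with a placeholder checkpoint path:

# Placeholder path; torch.load only inspects the saved trainer state here.
import torch

ckpt = torch.load("path/to/last.ckpt")
saved_epoch = ckpt["epoch"]  # the epoch that finished before the save, e.g. 0
# After `trainer.fit(model, ckpt_path=...)` resumes, inside `on_train_start`:
#     trainer.current_epoch == saved_epoch + 1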
