Skip to content

Commit dd0c40e

Browse files
authored
Fix current epoch value override on restart (#12429)
1 parent b2e98d6 commit dd0c40e

File tree

4 files changed

+19
-21
lines changed

4 files changed

+19
-21
lines changed

pytorch_lightning/loops/fit_loop.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,6 @@ def restarting(self, restarting: bool) -> None:
131131
self.epoch_progress.current.processed,
132132
)
133133
finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values)
134-
if finished_before_on_train_end:
135-
self.epoch_progress.current.completed = self.epoch_progress.current.processed
136134
restarting &= finished_before_on_train_end
137135
Loop.restarting.fset(self, restarting) # call the parent setter
138136

@@ -168,6 +166,9 @@ def done(self) -> bool:
168166
# `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
169167
# we use it here because the checkpoint data won't have `completed` increased yet
170168
stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
169+
if stop_epochs:
170+
# in case they are not equal, override so `trainer.current_epoch` has the expected value
171+
self.epoch_progress.current.completed = self.epoch_progress.current.processed
171172

172173
should_stop = False
173174
if self.trainer.should_stop:

tests/models/test_hooks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
617617
"state_dict": ANY,
618618
"loops": ANY,
619619
}
620-
saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload, "epoch": 1}
620+
saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload}
621621
expected = [
622622
dict(name="Callback.on_init_start", args=(trainer,)),
623623
dict(name="Callback.on_init_end", args=(trainer,)),
@@ -647,7 +647,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
647647
dict(name="on_epoch_start"),
648648
dict(name="Callback.on_train_epoch_start", args=(trainer, model)),
649649
dict(name="on_train_epoch_start"),
650-
*model._train_batch(trainer, model, steps_after_reload, current_batch=1, current_epoch=1),
650+
*model._train_batch(trainer, model, steps_after_reload, current_batch=1),
651651
dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
652652
dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
653653
dict(name="Callback.state_dict"),

tests/models/test_restore.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,7 @@ def on_train_start(self):
199199
if self.trainer.state.fn == TrainerFn.TUNING:
200200
self._test_on_val_test_predict_tune_start()
201201
else:
202-
# `-1` because this checkpoint is saved `on_train_epoch_end` which is considered part of the epoch so
203-
# the `current_epoch` count has not been increased yet
204-
assert self.trainer.current_epoch - 1 == state_dict["epoch"]
202+
assert self.trainer.current_epoch == state_dict["epoch"]
205203
assert self.trainer.global_step == state_dict["global_step"]
206204
assert self._check_model_state_dict()
207205
assert self._check_optimizers()

tests/trainer/test_trainer.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def test_model_freeze_unfreeze():
399399
assert param.requires_grad
400400

401401

402-
@pytest.mark.xfail(reason="FIXME(@carmocca): this test wasn't running and is now broken")
402+
# TODO: move to `test/models/test_restore.py`
403403
@pytest.mark.parametrize("url_ckpt", [True, False])
404404
def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
405405
"""Verify resuming from checkpoint runs the right number of epochs."""
@@ -422,11 +422,12 @@ def on_load_checkpoint(self, _):
422422
self.num_on_load_checkpoint_called += 1
423423

424424
model = TestModel()
425+
max_epochs = 2
425426
trainer = Trainer(
426-
max_epochs=2,
427+
max_epochs=max_epochs,
427428
limit_train_batches=0.65,
428429
limit_val_batches=1,
429-
callbacks=[ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)],
430+
callbacks=ModelCheckpoint(dirpath=tmpdir, save_top_k=-1),
430431
default_root_dir=tmpdir,
431432
val_check_interval=1.0,
432433
enable_progress_bar=False,
@@ -435,27 +436,25 @@ def on_load_checkpoint(self, _):
435436
)
436437
trainer.fit(model)
437438

438-
assert model.num_epochs_end_seen == 2
439-
assert model.num_batches_seen == trainer.num_training_batches * 2
439+
assert model.num_epochs_end_seen == max_epochs
440+
assert model.num_batches_seen == trainer.num_training_batches * max_epochs == trainer.global_step
440441
assert model.num_on_load_checkpoint_called == 0
441442

442-
# Other checkpoints can be uncommented if/when resuming mid-epoch is supported
443-
checkpoints = Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt")
443+
checkpoints = set(Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt"))
444444
if url_ckpt:
445445
# transform local paths into url checkpoints
446446
ip, port = tmpdir_server
447447
checkpoints = [f"http://{ip}:{port}/" + ckpt.name for ckpt in checkpoints]
448448

449-
assert checkpoints
449+
assert len(checkpoints) == max_epochs
450450
for ckpt in checkpoints:
451-
next_model = TestModel()
451+
model = TestModel()
452452
state = pl_load(ckpt)
453-
454453
# Resume training
455-
new_trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)
456-
new_trainer.fit(next_model, ckpt_path=ckpt)
457-
assert state["global_step"] + next_model.num_batches_seen == trainer.num_training_batches * trainer.max_epochs
458-
assert next_model.num_on_load_checkpoint_called == 1
454+
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, enable_progress_bar=False)
455+
trainer.fit(model, ckpt_path=ckpt)
456+
assert state["global_step"] + model.num_batches_seen == trainer.global_step
457+
assert model.num_on_load_checkpoint_called == 1
459458

460459

461460
def test_trainer_max_steps_and_epochs(tmpdir):

0 commit comments

Comments (0)