Commit e64c200

Fix edge cases and start from last with and without val
1 parent 0012dcb commit e64c200

File tree

4 files changed: +181 −21 lines changed


src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 0 additions & 9 deletions

@@ -206,15 +206,6 @@ def restarting_mid_evaluation(self) -> bool:
             and self.batch_progress.total.completed == self.batch_progress.total.processed
         )

-    @property
-    def restarting_on_evaluation_end(self) -> bool:
-        return (
-            self.restarting
-            and self.batch_progress.total.started == self.batch_progress.total.ready
-            and self.batch_progress.total.processed == self.batch_progress.total.started
-            and self.batch_progress.total.completed == self.batch_progress.total.processed - 1
-        )
-
     def reset(self) -> None:
         """Resets the internal state of the loop."""
         trainer = self.trainer
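
Note: the removed restarting_on_evaluation_end helper, like the restarting_mid_evaluation check kept above, works by comparing the four batch-progress counters to infer where the checkpoint was written. A minimal, self-contained sketch of that idea, using a toy dataclass rather than Lightning's internal progress classes:

# Toy stand-in for Lightning's progress counters (illustrative only).
from dataclasses import dataclass


@dataclass
class ToyProgress:
    ready: int = 0      # batch fetched from the dataloader
    started: int = 0    # batch handed to the step
    processed: int = 0  # step finished
    completed: int = 0  # batch-end hooks finished


def saved_on_batch_end(p: ToyProgress) -> bool:
    # Same shape of check as the removed `restarting_on_evaluation_end`:
    # everything ran for the last batch except the final `completed` bump,
    # so the checkpoint must have been written in the batch-end hook.
    return (
        p.started == p.ready
        and p.processed == p.started
        and p.completed == p.processed - 1
    )


# Counters as they would look right after loading such a checkpoint.
assert saved_on_batch_end(ToyProgress(ready=2, started=2, processed=2, completed=1))
assert not saved_on_batch_end(ToyProgress(ready=2, started=2, processed=2, completed=2))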

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 29 additions & 0 deletions

@@ -329,14 +329,43 @@ def restarting_on_epoch_end(self) -> bool:
             and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1
         )

+    @property
+    def progress_at_epoch_end(self) -> bool:
+        # TODO LUCA comment for restart last without val
+        return (
+            self.epoch_progress.total.started == self.epoch_progress.total.ready
+            and self.epoch_progress.total.processed == self.epoch_progress.total.started
+            and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1
+        )
+
     def reset(self) -> None:
         """Resets the internal state of this loop."""
         assert self.trainer.model is not None
         torch.set_grad_enabled(True)

+        self.epoch_loop.reset_restarting_states()
+
         if self.restarting_on_epoch_start:
             self.epoch_progress.reset_on_restart()

+        if self.progress_at_epoch_end:
+            self.epoch_progress.increment_completed()
+
+        # TODO LUCA: refactor restarting for fit_loop
+        restarting_mid_epoch = self.restarting_mid_epoch
+
+        if (self.epoch_loop.restarting_on_train_batch_end
+                and self.restarting_mid_epoch
+                and self.epoch_loop.batch_progress.is_last_batch):
+            self.epoch_progress.increment_processed()
+            self.epoch_progress.increment_completed()
+
+        if (self.epoch_loop.restarting_on_train_batch_end
+                and self.epoch_loop.batch_progress.is_last_batch
+                and not restarting_mid_epoch
+                and not self.epoch_loop.val_loop.batch_progress.is_last_batch):
+            self.epoch_progress.increment_completed()
+
     def on_run_start(self) -> None:
         """Calls the ``on_train_start`` hook."""
         # update the current_epoch in-case of checkpoint reload
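
The new progress_at_epoch_end property and the extra bookkeeping in reset() cover checkpoints written at the very end of an epoch: the epoch counters come back with completed lagging processed by one, and the reset path finishes the increment so the restored run does not redo the epoch. A rough sketch of that condition with plain integers (illustrative only, not Lightning code):

# Plain-integer sketch of the epoch-end condition used above (not Lightning code).
epoch = {"ready": 1, "started": 1, "processed": 1, "completed": 0}


def progress_at_epoch_end(p: dict) -> bool:
    # The epoch was fully processed but `completed` was never bumped,
    # i.e. the checkpoint was written right at the end of the epoch.
    return (
        p["started"] == p["ready"]
        and p["processed"] == p["started"]
        and p["completed"] == p["processed"] - 1
    )


if progress_at_epoch_end(epoch):
    epoch["completed"] += 1  # mirrors `self.epoch_progress.increment_completed()`

assert epoch == {"ready": 1, "started": 1, "processed": 1, "completed": 1}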

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 37 additions & 9 deletions

@@ -81,6 +81,8 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s
         self._results = _ResultCollection(training=True)
         self._warning_cache = WarningCache()
         self._batches_that_stepped: int = 0
+        self._restarting_on_train_batch_end: bool = None
+        self._restarting_on_last: bool = None

     @property
     def total_batch_idx(self) -> int:
@@ -146,15 +148,43 @@ def run(self, data_fetcher: _DataFetcher) -> None:

     @property
     def restarting_on_train_batch_end(self) -> bool:
-        return (
-            self.restarting
-            and self.batch_progress.total.started == self.batch_progress.total.ready
-            and self.batch_progress.total.processed == self.batch_progress.total.started
-            and self.batch_progress.total.completed == self.batch_progress.total.processed - 1
-        )
+        if self._restarting_on_train_batch_end is None:
+            self._restarting_on_train_batch_end = (
+                self.restarting
+                and self.batch_progress.total.started == self.batch_progress.total.ready
+                and self.batch_progress.total.processed == self.batch_progress.total.started
+                and self.batch_progress.total.completed == self.batch_progress.total.processed - 1
+            )
+        return self._restarting_on_train_batch_end
+
+    @property
+    def restarting_on_last(self) -> bool:
+        if self._restarting_on_last is None:
+            self._restarting_on_last = (
+                self.restarting
+                and self.batch_progress.total.started == self.batch_progress.total.ready
+                and self.batch_progress.total.processed == self.batch_progress.total.started
+                and self.batch_progress.total.completed == self.batch_progress.total.processed
+            )
+        return self._restarting_on_last
+
+    def reset_restarting_states(self) -> None:
+        self._restarting_on_train_batch_end = None
+        self._restarting_on_last = None
+        self.restarting_on_train_batch_end
+        self.restarting_on_last

     def reset(self) -> None:
+        self.reset_restarting_states()
         """Resets the internal state of the loop for a new run."""
+        if self.restarting and not self._should_accumulate():
+            # batches_that_stepped is never set prior to saving a checkpoint, even when saving
+            # happens on_validation_end
+            # we could set it in the checkpoint but we prefer to keep checkpoints backward compatible
+            if self.restarting_on_train_batch_end or not self.restarting_on_last:
+                # if not self.restarting_on_train_batch_end and not self.restarting_on_last:
+                self._batches_that_stepped += 1
+
         if self.restarting_on_train_batch_end:
             self.batch_progress.increment_completed()
             # handle situation in which save happened on_train_batch_end and epoch is at end
@@ -163,8 +193,6 @@ def reset(self) -> None:
             self.scheduler_progress.reset_on_run()
             self.automatic_optimization.optim_progress.reset_on_run()
             self.val_loop.batch_progress.total.reset()
-            if not self._should_accumulate():
-                self._batches_that_stepped += 1

         if self.restarting:
             self.batch_progress.reset_on_restart()
@@ -217,7 +245,7 @@ def advance(self, data_fetcher: _DataFetcher) -> None:

         """
         if self.restarting and self._should_check_val_fx(data_fetcher):
-            if self.val_loop.restarting_mid_evaluation:
+            if self.val_loop.restarting_mid_evaluation or self.restarting_on_last:
                 return
             # fast forward progress counters to end of validation
             self.val_loop.increment_progress_to_evaluation_end()
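
restarting_on_train_batch_end and restarting_on_last are derived from counters that reset() itself mutates, which is why they are now computed once per restart and cached until reset_restarting_states() clears them. A minimal sketch of that compute-once pattern (illustrative only; the class and attribute names below are made up):

# Illustrative compute-once pattern; not Lightning code.
from typing import Optional


class RestartFlags:
    def __init__(self) -> None:
        self.processed = 1
        self.completed = 0
        self._on_batch_end: Optional[bool] = None

    @property
    def on_batch_end(self) -> bool:
        # Evaluate the counter comparison only once, then pin the result.
        if self._on_batch_end is None:
            self._on_batch_end = self.completed == self.processed - 1
        return self._on_batch_end

    def reset_flags(self) -> None:
        # Counterpart of `reset_restarting_states`: drop the cache so the
        # flag is recomputed from the (possibly updated) counters.
        self._on_batch_end = None


flags = RestartFlags()
assert flags.on_batch_end       # computed from the counters: 0 == 1 - 1
flags.completed += 1            # later counter mutation, e.g. during reset...
assert flags.on_batch_end       # ...does not flip the pinned value
flags.reset_flags()
assert not flags.on_batch_end   # recomputed only after an explicit reset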

tests/tests_pytorch/loops/test_loops.py

Lines changed: 115 additions & 3 deletions

@@ -604,16 +604,24 @@ def test_fit_loop_reset(tmp_path):

     # we load exactly what was saved - no reset yet
     fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"])
+
+    assert fit_loop.restarting
+    assert fit_loop.epoch_progress.total.ready == 1
+    assert fit_loop.epoch_progress.total.completed == 0
+    assert fit_loop.epoch_progress.current.ready == 1
+    assert fit_loop.epoch_progress.current.completed == 0
+
     # resetting from a end-of-epoch checkpoint SHOULD reset the current counters to 0
     fit_loop.reset()
     epoch_loop.reset()

     # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0
+    # since we are restarting at the end of epoch, we need to see `completed` being updated after reset
     assert fit_loop.restarting
     assert fit_loop.epoch_progress.total.ready == 1
-    assert fit_loop.epoch_progress.total.completed == 0
+    assert fit_loop.epoch_progress.total.completed == 1
     assert fit_loop.epoch_progress.current.ready == 1
-    assert fit_loop.epoch_progress.current.completed == 0
+    assert fit_loop.epoch_progress.current.completed == 1

     # however it should increment completed batch progress, since it was saved immediately prior
     assert epoch_loop.restarting
@@ -704,6 +712,7 @@ def test_restart_parity(tmp_path):
         callbacks=[checkpoint_callback],
         logger=False,
         enable_model_summary=False,
+        enable_progress_bar=False,
     )
     trainer.fit(model)
     loss = model.last_loss
@@ -715,6 +724,7 @@ def test_restart_parity(tmp_path):
         callbacks=[checkpoint_callback],
         logger=False,
         enable_model_summary=False,
+        enable_progress_bar=False,
     )
     trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt"))
     loss_v1 = model.last_loss
@@ -749,7 +759,7 @@ def test_restart_parity(tmp_path):
     assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {}


-def test_restart_parity_with_val(tmp_path):
+def test_restart_with_val_parity(tmp_path):
     model = PredictableBoringModel()
     checkpoint_callback = ModelCheckpoint(
         dirpath=tmp_path,
@@ -814,6 +824,108 @@ def test_restart_parity_with_val(tmp_path):
     assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {}


+def test_restart_from_last_parity(tmp_path):
+    model = PredictableBoringModel()
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path,
+        save_last=True,
+        save_top_k=-1,
+    )
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        max_epochs=4,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+    trainer.fit(model)
+
+    last_ckpt_1 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        max_epochs=2,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+    trainer.fit(model)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        max_epochs=4,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+    trainer.fit(model, ckpt_path=str(tmp_path / "last.ckpt"))
+
+    last_ckpt_2 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True)
+
+    assert compare_state_dicts(last_ckpt_1["loops"], last_ckpt_2["loops"]) == {}
+
+
+def test_restart_from_last_with_val_parity(tmp_path):
+    model = PredictableBoringModel()
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path,
+        save_last=True,
+        save_top_k=-1,
+    )
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        max_epochs=4,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        limit_val_batches=2,
+        val_check_interval=2,
+    )
+    trainer.fit(model)
+
+    last_ckpt_1 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        max_epochs=2,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        limit_val_batches=2,
+        val_check_interval=2,
+    )
+    trainer.fit(model)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        max_epochs=4,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        limit_val_batches=2,
+        val_check_interval=2,
+    )
+    trainer.fit(model, ckpt_path=str(tmp_path / "last.ckpt"))
+
+    last_ckpt_2 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True)
+
+    assert compare_state_dicts(last_ckpt_1["loops"], last_ckpt_2["loops"]) == {}
+
+
 @pytest.mark.parametrize(
     ("train_datasets", "val_datasets"),
     [([RandomDataset], [RandomDataset]), ([RandomDataset], [RandomDataset, RandomDataset])],
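
The new parity tests above rely on the compare_state_dicts helper defined elsewhere in this test module; an empty dict means the two checkpoints' "loops" sections match. Its actual implementation may differ, but a minimal sketch of such a recursive comparison could look like this (the name compare_state_dicts_sketch and its return convention are assumptions for illustration):

# Hypothetical sketch of a `compare_state_dicts`-style helper (the real one in
# the test suite may differ): returns an empty dict when two nested state dicts
# are identical, otherwise maps each differing path to the mismatching values.
import torch


def compare_state_dicts_sketch(a: dict, b: dict, path: str = "") -> dict:
    diffs = {}
    for key in sorted(set(a) | set(b), key=str):
        here = f"{path}.{key}" if path else str(key)
        if key not in a or key not in b:
            diffs[here] = "missing on one side"
        elif isinstance(a[key], dict) and isinstance(b[key], dict):
            diffs.update(compare_state_dicts_sketch(a[key], b[key], here))
        elif isinstance(a[key], torch.Tensor) and isinstance(b[key], torch.Tensor):
            if not torch.equal(a[key], b[key]):
                diffs[here] = (a[key], b[key])
        elif a[key] != b[key]:
            diffs[here] = (a[key], b[key])
    return diffs


assert compare_state_dicts_sketch({"epoch": 3, "step": 8}, {"epoch": 3, "step": 8}) == {}
assert compare_state_dicts_sketch({"epoch": 3}, {"epoch": 4}) == {"epoch": (3, 4)}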
