Commit 9358898

Ensure restarting from checkpoints leads to consistent internal counters (#20379)
* Fix checkpoint progress for fit loop and batch loop
* Check loss parity
* Rename test
* Fix validation loop handling on restart
* Fix loop reset test
* Avoid skipping to val end if saved mid validation
* Fix type checks in compare state dicts
* Fix edge cases and start from last with and without val
* Clean up
* Formatting
* Avoid running validation when restarting from last
* Fix type annotations
* Fix formatting
* Ensure int max_batch
* Fix condition on batches that stepped
* Remove expected on_train_epoch_start when restarting mid epoch
1 parent 7038b8d commit 9358898
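This commit hardens the resume path: when `Trainer.fit` is pointed back at a checkpoint, every loop's progress counters must line up so that no batch is skipped or repeated and no hook fires twice. A minimal sketch of the scenario being fixed (`MyModel` is a hypothetical stand-in for any LightningModule):

```python
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

model = MyModel()  # hypothetical user-defined LightningModule
ckpt_cb = ModelCheckpoint(save_last=True, every_n_train_steps=50)

trainer = pl.Trainer(max_epochs=10, callbacks=[ckpt_cb])
trainer.fit(model)  # suppose this run is interrupted mid-epoch

# On restart, the fit/epoch/validation loops must reconcile their
# ready/started/processed/completed counters from the checkpoint.
trainer = pl.Trainer(max_epochs=10, callbacks=[ckpt_cb])
trainer.fit(model, ckpt_path="last")
```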

File tree

7 files changed: +584 -32 lines changed


src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 37 additions & 1 deletion
```diff
@@ -15,6 +15,7 @@
 import shutil
 import sys
 from collections import ChainMap, OrderedDict, defaultdict
+from dataclasses import dataclass
 from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union
 
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -45,6 +46,12 @@
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
 
 
+@dataclass
+class RestartStage:
+    NONE = "none"
+    RESTARTED_MID_EVALUATION = "restarted_mid_evaluation"
+
+
 class _EvaluationLoop(_Loop):
     """Top-level loop where validation/testing starts."""
 
@@ -73,6 +80,7 @@ def __init__(
         self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int)
         self._last_val_dl_reload_epoch = float("-inf")
         self._module_mode = _ModuleMode()
+        self._restart_stage = RestartStage.NONE
 
     @property
     def num_dataloaders(self) -> int:
@@ -137,7 +145,7 @@ def run(self) -> List[_OUT_DICT]:
                 # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
                 break
             finally:
-                self._restarting = False
+                self.on_iteration_done()
         self._store_dataloader_outputs()
         return self.on_run_end()
 
@@ -197,6 +205,24 @@ def setup_data(self) -> None:
         # this depends on the data used, so reset it too
         self._seen_batches_per_dataloader = defaultdict(int)
 
+    @property
+    def restarted_mid_evaluation(self) -> bool:
+        return self._restart_stage == RestartStage.RESTARTED_MID_EVALUATION
+
+    def update_restart_stage(self) -> None:
+        if (
+            self.restarting
+            and self.batch_progress.total.started == self.batch_progress.total.ready
+            and self.batch_progress.total.processed == self.batch_progress.total.started - 1
+            and self.batch_progress.total.completed == self.batch_progress.total.processed
+        ):
+            self._restart_stage = RestartStage.RESTARTED_MID_EVALUATION
+        else:
+            self._restart_stage = RestartStage.NONE
+
+    def reset_restart_stage(self) -> None:
+        self._restart_stage = RestartStage.NONE
+
     def reset(self) -> None:
         """Resets the internal state of the loop."""
         trainer = self.trainer
@@ -236,6 +262,16 @@ def reset(self) -> None:
             data_fetcher._stop_profiler = self._on_after_fetch
         self._data_fetcher = data_fetcher
 
+    def increment_progress_to_evaluation_end(self) -> None:
+        self.setup_data()
+        if self.skip:
+            return
+        self.reset()
+        max_batch = int(max(self.max_batches))
+        if max_batch == -1:
+            return
+        self.batch_progress.increment_by(max_batch, True)
+
     def on_run_start(self) -> None:
         """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
         hooks."""
```
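The counter arithmetic in `update_restart_stage` above encodes one specific snapshot: the checkpoint was written after a validation batch was started but before it was processed. A self-contained sketch of that predicate, with a plain dataclass standing in for `batch_progress.total` (names here are illustrative, not Lightning API):

```python
from dataclasses import dataclass


@dataclass
class Totals:
    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0


def is_mid_evaluation(t: Totals) -> bool:
    # Mirrors the condition in _EvaluationLoop.update_restart_stage:
    # every ready batch was started, but the last started batch was
    # never processed (and hence never completed) before the save.
    return (
        t.started == t.ready
        and t.processed == t.started - 1
        and t.completed == t.processed
    )


# Checkpoint written mid-validation: batch 3 was handed out, not finished.
assert is_mid_evaluation(Totals(ready=3, started=3, processed=2, completed=2))
# Checkpoint written at a clean boundary: all counters agree.
assert not is_mid_evaluation(Totals(ready=3, started=3, processed=3, completed=3))
```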

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 99 additions & 8 deletions
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union
 
 import torch
@@ -45,6 +46,15 @@
 log = logging.getLogger(__name__)
 
 
+@dataclass
+class RestartStage:
+    NONE = "none"
+    RESTARTED_ON_EPOCH_START = "restarted_on_epoch_start"
+    RESTARTED_MID_EPOCH = "restarted_mid_epoch"
+    RESTARTED_ON_EPOCH_END = "restarted_on_epoch_end"
+    RESUMED_ON_EPOCH_END = "resumed_on_epoch_end"
+
+
 class _FitLoop(_Loop):
     """This loop is the top-level loop where training starts.
 
@@ -97,6 +107,7 @@ def __init__(
         self._combined_loader_states_to_load: List[Dict[str, Any]] = []
         self._data_fetcher: Optional[_DataFetcher] = None
         self._last_train_dl_reload_epoch = float("-inf")
+        self._restart_stage = RestartStage.NONE
 
     @property
     def total_batch_idx(self) -> int:
@@ -204,9 +215,10 @@ def run(self) -> None:
                 self.on_advance_start()
                 self.advance()
                 self.on_advance_end()
-                self._restarting = False
             except StopIteration:
                 break
+            finally:
+                self.on_iteration_done()
         self._restarting = False
         self.on_run_end()
 
@@ -302,14 +314,92 @@ def setup_data(self) -> None:
                 category=PossibleUserWarning,
             )
 
+    @property
+    def restarted_on_epoch_start(self) -> bool:
+        return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_START
+
+    @property
+    def restarted_mid_epoch(self) -> bool:
+        return self._restart_stage == RestartStage.RESTARTED_MID_EPOCH
+
+    @property
+    def restarted_on_epoch_end(self) -> bool:
+        return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_END
+
+    @property
+    def resumed_on_epoch_end(self) -> bool:
+        # This case happens when restarting from last without validation at
+        # the end of epoch. In this case self.restarting is False.
+        return self._restart_stage == RestartStage.RESUMED_ON_EPOCH_END
+
+    def update_restart_stage(self) -> None:
+        if (
+            self.restarting
+            and self.epoch_progress.total.started == self.epoch_progress.total.ready - 1
+            and self.epoch_progress.total.processed == self.epoch_progress.total.started
+            and self.epoch_progress.total.completed == self.epoch_progress.total.processed
+        ):
+            self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_START
+        elif (
+            self.restarting
+            and self.epoch_progress.total.started == self.epoch_progress.total.ready
+            and self.epoch_progress.total.processed == self.epoch_progress.total.started - 1
+            and self.epoch_progress.total.completed == self.epoch_progress.total.processed
+        ):
+            self._restart_stage = RestartStage.RESTARTED_MID_EPOCH
+        elif (
+            self.restarting
+            and self.epoch_progress.total.started == self.epoch_progress.total.ready
+            and self.epoch_progress.total.processed == self.epoch_progress.total.started
+            and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1
+        ):
+            self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_END
+        elif (
+            self._loaded_from_state_dict
+            and self.epoch_progress.total.started == self.epoch_progress.total.ready
+            and self.epoch_progress.total.processed == self.epoch_progress.total.started
+            and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1
+        ):
+            self._restart_stage = RestartStage.RESUMED_ON_EPOCH_END
+        else:
+            self._restart_stage = RestartStage.NONE
+
+        self.epoch_loop.update_restart_stage()
+
+    def reset_restart_stage(self) -> None:
+        self._restart_stage = RestartStage.NONE
+
     def reset(self) -> None:
         """Resets the internal state of this loop."""
         assert self.trainer.model is not None
         torch.set_grad_enabled(True)
 
-        if self.restarting:
+        self.update_restart_stage()
+
+        if self.restarted_on_epoch_start:
             self.epoch_progress.reset_on_restart()
 
+        if self.resumed_on_epoch_end:
+            # when restarting from last without validation at end of epoch,
+            # self.restarting is False but it's still resuming
+            self.epoch_progress.increment_completed()
+
+        if (
+            self.epoch_loop.restarted_on_train_batch_end
+            and self.restarted_mid_epoch
+            and self.epoch_loop.batch_progress.is_last_batch
+        ):
+            self.epoch_progress.increment_processed()
+            self.epoch_progress.increment_completed()
+
+        if (
+            self.epoch_loop.restarted_on_train_batch_end
+            and self.epoch_loop.batch_progress.is_last_batch
+            and not self.restarted_mid_epoch
+            and not self.epoch_loop.val_loop.batch_progress.is_last_batch
+        ):
+            self.epoch_progress.increment_completed()
+
     def on_run_start(self) -> None:
         """Calls the ``on_train_start`` hook."""
         # update the current_epoch in-case of checkpoint reload
@@ -340,12 +430,14 @@ def on_advance_start(self) -> None:
         for i, dl in enumerate(self._combined_loader.flattened):
             _set_sampler_epoch(dl, self.epoch_progress.current.processed)
 
-        self.epoch_progress.increment_ready()
+        if not self.restarted_mid_epoch and not self.restarted_on_epoch_end:
+            if not self.restarted_on_epoch_start:
+                self.epoch_progress.increment_ready()
 
-        call._call_callback_hooks(trainer, "on_train_epoch_start")
-        call._call_lightning_module_hook(trainer, "on_train_epoch_start")
+            call._call_callback_hooks(trainer, "on_train_epoch_start")
+            call._call_lightning_module_hook(trainer, "on_train_epoch_start")
 
-        self.epoch_progress.increment_started()
+            self.epoch_progress.increment_started()
 
     def advance(self) -> None:
         """Runs one whole epoch."""
@@ -379,8 +471,7 @@ def on_advance_end(self) -> None:
 
         trainer._logger_connector.on_epoch_end()
 
-        if self.epoch_loop._num_ready_batches_reached():
-            # if we are restarting and the above condition holds, it's because we are reloading an epoch-end checkpoint.
+        if not self.restarting and self.epoch_loop._num_ready_batches_reached():
             # since metric-based schedulers require access to metrics and those are not currently saved in the
             # checkpoint, the plateau schedulers shouldn't be updated
             self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting)
```
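Each branch of `_FitLoop.update_restart_stage` pins down where in the epoch lifecycle the checkpoint was written, purely from the relative offsets of the four epoch counters. A runnable sketch of the same classification, using a plain dataclass instead of the real `epoch_progress.total` tracker (illustrative names, not Lightning API):

```python
from dataclasses import dataclass


@dataclass
class EpochTotals:
    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0


def classify(t: EpochTotals, restarting: bool, loaded_from_state_dict: bool) -> str:
    # Mirrors _FitLoop.update_restart_stage: each stage is a distinct
    # "staircase" of the four counters at save time.
    if restarting and t.started == t.ready - 1 and t.processed == t.started and t.completed == t.processed:
        return "restarted_on_epoch_start"  # epoch marked ready, never started
    if restarting and t.started == t.ready and t.processed == t.started - 1 and t.completed == t.processed:
        return "restarted_mid_epoch"       # started, not yet processed
    if restarting and t.started == t.ready and t.processed == t.started and t.completed == t.processed - 1:
        return "restarted_on_epoch_end"    # processed, not yet completed
    if loaded_from_state_dict and t.started == t.ready and t.processed == t.started and t.completed == t.processed - 1:
        return "resumed_on_epoch_end"      # resumed from `last` without val; restarting is False
    return "none"


assert classify(EpochTotals(2, 1, 1, 1), True, True) == "restarted_on_epoch_start"
assert classify(EpochTotals(2, 2, 1, 1), True, True) == "restarted_mid_epoch"
assert classify(EpochTotals(2, 2, 2, 1), True, True) == "restarted_on_epoch_end"
assert classify(EpochTotals(2, 2, 2, 1), False, True) == "resumed_on_epoch_end"
```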

src/lightning/pytorch/loops/loop.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -22,6 +22,7 @@ class _Loop:
 
     def __init__(self, trainer: "pl.Trainer") -> None:
         self._restarting = False
+        self._loaded_from_state_dict = False
         self.trainer = trainer
 
     @property
@@ -37,6 +38,9 @@ def restarting(self, restarting: bool) -> None:
             if isinstance(loop, _Loop):
                 loop.restarting = restarting
 
+    def reset_restart_stage(self) -> None:
+        pass
+
     def on_save_checkpoint(self) -> Dict:
         """Called when saving a model checkpoint, use to persist loop state.
 
@@ -82,6 +86,7 @@ def load_state_dict(
             if isinstance(v, _Loop):
                 v.load_state_dict(state_dict.copy(), prefix + k + ".")
         self.restarting = True
+        self._loaded_from_state_dict = True
 
     def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None:
         for k, v in self.__dict__.items():
@@ -93,3 +98,8 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None:
                 v.load_state_dict(state_dict[key])
         if prefix + "state_dict" in state_dict:  # compatibility with old checkpoints
             self.on_load_checkpoint(state_dict[prefix + "state_dict"])
+
+    def on_iteration_done(self) -> None:
+        self._restarting = False
+        self._loaded_from_state_dict = False
+        self.reset_restart_stage()
```
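Together these hooks define the flag lifecycle: `load_state_dict` raises `restarting` and `_loaded_from_state_dict`, and the first completed loop iteration lowers both through `on_iteration_done()` (called from the loops' `finally` blocks above). A toy sketch of that contract, ignoring child loops and actual state restoration:

```python
class MiniLoop:
    """Toy stand-in for _Loop showing only the restart-flag lifecycle."""

    def __init__(self) -> None:
        self._restarting = False
        self._loaded_from_state_dict = False

    def load_state_dict(self, state_dict: dict) -> None:
        # The real _Loop also restores child loops and progress objects here.
        self._restarting = True
        self._loaded_from_state_dict = True

    def reset_restart_stage(self) -> None:
        pass  # overridden by loops that track a RestartStage

    def on_iteration_done(self) -> None:
        # After one full iteration the loop is "live" again, so restart
        # bookkeeping must not leak into subsequent iterations.
        self._restarting = False
        self._loaded_from_state_dict = False
        self.reset_restart_stage()


loop = MiniLoop()
loop.load_state_dict({})
assert loop._restarting and loop._loaded_from_state_dict
loop.on_iteration_done()  # first iteration after resume has finished
assert not loop._restarting and not loop._loaded_from_state_dict
```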

src/lightning/pytorch/loops/progress.py

Lines changed: 22 additions & 0 deletions
```diff
@@ -68,6 +68,10 @@ def reset_on_restart(self) -> None:
         """
         self.ready = self.completed
 
+    def increment_by(self, n: int) -> None:
+        self.ready += n
+        self.completed += n
+
 
 @dataclass
 class _StartedTracker(_ReadyCompletedTracker):
@@ -94,6 +98,11 @@ def reset_on_restart(self) -> None:
         super().reset_on_restart()
         self.started = self.completed
 
+    @override
+    def increment_by(self, n: int) -> None:
+        super().increment_by(n)
+        self.started += n
+
 
 @dataclass
 class _ProcessedTracker(_StartedTracker):
@@ -121,6 +130,11 @@ def reset_on_restart(self) -> None:
         super().reset_on_restart()
         self.processed = self.completed
 
+    @override
+    def increment_by(self, n: int) -> None:
+        super().increment_by(n)
+        self.processed += n
+
 
 @dataclass
 class _Progress(_BaseProgress):
@@ -175,6 +189,10 @@ def reset_on_run(self) -> None:
     def reset_on_restart(self) -> None:
         self.current.reset_on_restart()
 
+    def increment_by(self, n: int) -> None:
+        self.total.increment_by(n)
+        self.current.increment_by(n)
+
     @override
     def load_state_dict(self, state_dict: dict) -> None:
         self.total.load_state_dict(state_dict["total"])
@@ -206,6 +224,10 @@ def reset_on_run(self) -> None:
         super().reset_on_run()
         self.is_last_batch = False
 
+    def increment_by(self, n: int, is_last_batch: bool = False) -> None:
+        super().increment_by(n)
+        self.is_last_batch = is_last_batch
+
     @override
     def load_state_dict(self, state_dict: dict) -> None:
         super().load_state_dict(state_dict)
```
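With `increment_by` cascading through the tracker hierarchy, a loop can fast-forward all four counters at once; this is what `increment_progress_to_evaluation_end` relies on to skip a validation pass that had already finished before the save. A usage sketch, assuming this commit's version of the (private) progress module:

```python
from lightning.pytorch.loops.progress import _BatchProgress

# Fast-forward a validation run of 5 batches, as if it had completed.
progress = _BatchProgress()
progress.increment_by(5, is_last_batch=True)

# total and current trackers advance in lockstep...
assert progress.total.ready == progress.total.completed == 5
assert progress.current.started == progress.current.processed == 5
# ...and the loop remembers it ended on the last batch.
assert progress.is_last_batch
```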
