Skip to content

Commit 59e41de

Browse files
committed
Clean up
1 parent e64c200 commit 59e41de

File tree

4 files changed

+146
-73
lines changed

4 files changed

+146
-73
lines changed

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import shutil
1616
import sys
1717
from collections import ChainMap, OrderedDict, defaultdict
18+
from dataclasses import dataclass
1819
from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union
1920

2021
from lightning_utilities.core.apply_func import apply_to_collection
@@ -45,6 +46,12 @@
4546
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
4647

4748

49+
@dataclass
class RestartStage:
    """Names of the restart stages tracked by the evaluation loop.

    The attributes below are un-annotated, so ``@dataclass`` registers no fields
    for them; they are plain class-level string constants that the loop stores
    in and compares against its ``_restart_stage`` attribute.
    """

    # Default stage: assigned in ``__init__`` and by ``reset_restart_stage``.
    NONE = "none"
    # Assigned by ``update_restart_stage`` when the loop is restarting and its
    # batch counters show a batch that was started but not yet processed.
    RESTARTED_MID_EVALUATION = "restarted_mid_evaluation"
53+
54+
4855
class _EvaluationLoop(_Loop):
4956
"""Top-level loop where validation/testing starts."""
5057

@@ -73,6 +80,7 @@ def __init__(
7380
self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int)
7481
self._last_val_dl_reload_epoch = float("-inf")
7582
self._module_mode = _ModuleMode()
83+
self._restart_stage = RestartStage.NONE
7684

7785
@property
7886
def num_dataloaders(self) -> int:
@@ -137,7 +145,7 @@ def run(self) -> List[_OUT_DICT]:
137145
# this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
138146
break
139147
finally:
140-
self._restarting = False
148+
self.on_iteration_done()
141149
self._store_dataloader_outputs()
142150
return self.on_run_end()
143151

@@ -198,14 +206,23 @@ def setup_data(self) -> None:
198206
self._seen_batches_per_dataloader = defaultdict(int)
199207

200208
@property
def restarted_mid_evaluation(self) -> bool:
    """Whether the stored restart stage is ``RestartStage.RESTARTED_MID_EVALUATION``."""
    current_stage = self._restart_stage
    return current_stage == RestartStage.RESTARTED_MID_EVALUATION
211+
212+
def update_restart_stage(self) -> None:
    """Record in ``_restart_stage`` whether the loop restarted in the middle of an evaluation run.

    The stage becomes ``RestartStage.RESTARTED_MID_EVALUATION`` when the loop is
    restarting and the total batch counters satisfy ``started == ready``,
    ``processed == started - 1`` and ``completed == processed``. In every other
    case the stage is ``RestartStage.NONE``.
    """
    stage = RestartStage.NONE
    if self.restarting:
        total = self.batch_progress.total
        started_but_unprocessed = (
            total.started == total.ready
            and total.processed == total.started - 1
            and total.completed == total.processed
        )
        if started_but_unprocessed:
            stage = RestartStage.RESTARTED_MID_EVALUATION
    self._restart_stage = stage
208222

223+
def reset_restart_stage(self) -> None:
    """Put the loop back into the ``RestartStage.NONE`` stage."""
    self._restart_stage = RestartStage.NONE
225+
209226
def reset(self) -> None:
210227
"""Resets the internal state of the loop."""
211228
trainer = self.trainer

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 69 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15+
from dataclasses import dataclass
1516
from typing import Any, Dict, List, Optional, Union
1617

1718
import torch
@@ -45,6 +46,15 @@
4546
log = logging.getLogger(__name__)
4647

4748

49+
@dataclass
class RestartStage:
    """Names of the restart stages tracked by the fit loop.

    The attributes below are un-annotated, so ``@dataclass`` registers no fields
    for them; they are plain class-level string constants that the loop stores
    in and compares against its ``_restart_stage`` attribute.
    """

    # Default stage: assigned in ``__init__`` and by ``reset_restart_stage``.
    NONE = "none"
    # The three ``RESTARTED_*`` stages are only assigned while ``restarting`` is true.
    RESTARTED_ON_EPOCH_START = "restarted_on_epoch_start"
    RESTARTED_MID_EPOCH = "restarted_mid_epoch"
    RESTARTED_ON_EPOCH_END = "restarted_on_epoch_end"
    # Same epoch counters as ``RESTARTED_ON_EPOCH_END``, but selected through
    # ``_loaded_from_state_dict`` when ``restarting`` is false.
    RESUMED_ON_EPOCH_END = "resumed_on_epoch_end"
56+
57+
4858
class _FitLoop(_Loop):
4959
"""This loop is the top-level loop where training starts.
5060
@@ -97,6 +107,7 @@ def __init__(
97107
self._combined_loader_states_to_load: List[Dict[str, Any]] = []
98108
self._data_fetcher: Optional[_DataFetcher] = None
99109
self._last_train_dl_reload_epoch = float("-inf")
110+
self._restart_stage = RestartStage.NONE
100111

101112
@property
102113
def total_batch_idx(self) -> int:
@@ -204,9 +215,10 @@ def run(self) -> None:
204215
self.on_advance_start()
205216
self.advance()
206217
self.on_advance_end()
207-
self._restarting = False
208218
except StopIteration:
209219
break
220+
finally:
221+
self.on_iteration_done()
210222
self._restarting = False
211223
self.on_run_end()
212224

@@ -303,67 +315,89 @@ def setup_data(self) -> None:
303315
)
304316

305317
@property
def restarted_on_epoch_start(self) -> bool:
    """Whether the stored restart stage is ``RestartStage.RESTARTED_ON_EPOCH_START``."""
    current_stage = self._restart_stage
    return current_stage == RestartStage.RESTARTED_ON_EPOCH_START
320+
321+
@property
def restarted_mid_epoch(self) -> bool:
    """Whether the stored restart stage is ``RestartStage.RESTARTED_MID_EPOCH``."""
    current_stage = self._restart_stage
    return current_stage == RestartStage.RESTARTED_MID_EPOCH
324+
325+
@property
def restarted_on_epoch_end(self) -> bool:
    """Whether the stored restart stage is ``RestartStage.RESTARTED_ON_EPOCH_END``."""
    current_stage = self._restart_stage
    return current_stage == RestartStage.RESTARTED_ON_EPOCH_END
328+
329+
@property
def resumed_on_epoch_end(self) -> bool:
    """Whether the stored restart stage is ``RestartStage.RESUMED_ON_EPOCH_END``.

    This case happens when restarting from last without validation at the end
    of the epoch. In this case ``self.restarting`` is False.
    """
    current_stage = self._restart_stage
    return current_stage == RestartStage.RESUMED_ON_EPOCH_END
334+
335+
def update_restart_stage(self) -> None:
    """Classify the current restart from the total epoch counters and store it in ``_restart_stage``.

    The first matching case wins:

    * ``RESTARTED_ON_EPOCH_START``: restarting, with ``started == ready - 1`` and
      ``processed == started == completed``.
    * ``RESTARTED_MID_EPOCH``: restarting, with ``started == ready`` and
      ``processed == completed == started - 1``.
    * ``RESTARTED_ON_EPOCH_END``: restarting, with ``started == ready == processed``
      and ``completed == processed - 1``.
    * ``RESUMED_ON_EPOCH_END``: same counters as the previous case, but selected
      through ``_loaded_from_state_dict`` instead of ``restarting``.
    * ``NONE``: anything else.

    The epoch loop is asked to update its own restart stage afterwards.
    """
    total = self.epoch_progress.total

    before_epoch_started = (
        total.started == total.ready - 1
        and total.processed == total.started
        and total.completed == total.processed
    )
    epoch_started_not_processed = (
        total.started == total.ready
        and total.processed == total.started - 1
        and total.completed == total.processed
    )
    epoch_processed_not_completed = (
        total.started == total.ready
        and total.processed == total.started
        and total.completed == total.processed - 1
    )

    if self.restarting and before_epoch_started:
        stage = RestartStage.RESTARTED_ON_EPOCH_START
    elif self.restarting and epoch_started_not_processed:
        stage = RestartStage.RESTARTED_MID_EPOCH
    elif self.restarting and epoch_processed_not_completed:
        stage = RestartStage.RESTARTED_ON_EPOCH_END
    elif self._loaded_from_state_dict and epoch_processed_not_completed:
        # Only reachable when ``restarting`` is false; otherwise the branch above matches first.
        stage = RestartStage.RESUMED_ON_EPOCH_END
    else:
        stage = RestartStage.NONE
    self._restart_stage = stage

    self.epoch_loop.update_restart_stage()
368+
369+
def reset_restart_stage(self) -> None:
    """Put the loop back into the ``RestartStage.NONE`` stage."""
    self._restart_stage = RestartStage.NONE
340371

341372
def reset(self) -> None:
    """Resets the internal state of this loop."""
    assert self.trainer.model is not None
    torch.set_grad_enabled(True)

    # Classify the restart once up front; this loop's ``restarted_*`` / ``resumed_*``
    # properties used below read the stage stored by this call.
    self.update_restart_stage()

    epoch_progress = self.epoch_progress
    epoch_loop = self.epoch_loop

    if self.restarted_on_epoch_start:
        epoch_progress.reset_on_restart()

    if self.resumed_on_epoch_end:
        # when restarting from last without validation at end of epoch,
        # self.restarting is False but it's still resuming
        epoch_progress.increment_completed()

    # Restarted mid-epoch, right after the last training batch ended:
    # mark the epoch as both processed and completed.
    if (
        epoch_loop.restarted_on_train_batch_end
        and self.restarted_mid_epoch
        and epoch_loop.batch_progress.is_last_batch
    ):
        epoch_progress.increment_processed()
        epoch_progress.increment_completed()

    # Restarted after the last training batch ended, but not mid-epoch and with
    # the validation loop not on its last batch: only mark the epoch completed.
    if (
        epoch_loop.restarted_on_train_batch_end
        and epoch_loop.batch_progress.is_last_batch
        and not self.restarted_mid_epoch
        and not epoch_loop.val_loop.batch_progress.is_last_batch
    ):
        epoch_progress.increment_completed()
368402

369403
def on_run_start(self) -> None:
@@ -396,8 +430,8 @@ def on_advance_start(self) -> None:
396430
for i, dl in enumerate(self._combined_loader.flattened):
397431
_set_sampler_epoch(dl, self.epoch_progress.current.processed)
398432

399-
if not self.restarting_mid_epoch and not self.restarting_on_epoch_end:
400-
if not self.restarting_on_epoch_start:
433+
if not self.restarted_mid_epoch and not self.restarted_on_epoch_end:
434+
if not self.restarted_on_epoch_start:
401435
self.epoch_progress.increment_ready()
402436

403437
call._call_callback_hooks(trainer, "on_train_epoch_start")

src/lightning/pytorch/loops/loop.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class _Loop:
2222

2323
def __init__(self, trainer: "pl.Trainer") -> None:
2424
self._restarting = False
25+
self._loaded_from_state_dict = False
2526
self.trainer = trainer
2627

2728
@property
@@ -37,6 +38,9 @@ def restarting(self, restarting: bool) -> None:
3738
if isinstance(loop, _Loop):
3839
loop.restarting = restarting
3940

41+
def reset_restart_stage(self) -> None:
    """Do nothing: the base loop stores no restart stage, so there is nothing to reset."""
    return None
43+
4044
def on_save_checkpoint(self) -> Dict:
4145
"""Called when saving a model checkpoint, use to persist loop state.
4246
@@ -82,6 +86,7 @@ def load_state_dict(
8286
if isinstance(v, _Loop):
8387
v.load_state_dict(state_dict.copy(), prefix + k + ".")
8488
self.restarting = True
89+
self._loaded_from_state_dict = True
8590

8691
def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None:
8792
for k, v in self.__dict__.items():
@@ -93,3 +98,8 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None:
9398
v.load_state_dict(state_dict[key])
9499
if prefix + "state_dict" in state_dict: # compatibility with old checkpoints
95100
self.on_load_checkpoint(state_dict[prefix + "state_dict"])
101+
102+
def on_iteration_done(self) -> None:
    """Clear the restart bookkeeping of this loop.

    Both flags that ``load_state_dict`` sets to true are put back to false, and
    ``reset_restart_stage`` is then called.
    """
    self._loaded_from_state_dict = False
    # Written to the private attribute directly, not through the ``restarting`` setter.
    self._restarting = False
    self.reset_restart_stage()

0 commit comments

Comments
 (0)