Skip to content

Commit 9c58810

Browse files
committed
Avoid running validation when restarting from last
1 parent c3469be commit 9c58810

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s
9090
self._warning_cache = WarningCache()
9191
self._batches_that_stepped: int = 0
9292
self._restart_stage = RestartStage.NONE
93+
self._skip_next_val = False
9394

9495
@property
9596
def total_batch_idx(self) -> int:
@@ -257,8 +258,15 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
257258
258259
"""
259260
if self.restarting and self._should_check_val_fx(data_fetcher):
260-
if self.val_loop.restarted_mid_evaluation or self.restarted_on_last:
261+
if self.val_loop.restarted_mid_evaluation:
262+
# Go back and finish running validation
261263
return
264+
265+
if self.restarted_on_last:
266+
# Avoid running validation again if we saved on last
267+
self._skip_next_val = True
268+
return
269+
262270
# fast forward progress counters to end of validation
263271
self.val_loop.increment_progress_to_evaluation_end()
264272

@@ -345,6 +353,11 @@ def on_advance_end(self, data_fetcher: _DataFetcher) -> None:
345353
# VALIDATE IF NEEDED
346354
# -----------------------------------------
347355
should_check_val = self._should_check_val_fx(data_fetcher)
356+
357+
if self._skip_next_val:
358+
should_check_val = False
359+
self._skip_next_val = False
360+
348361
if should_check_val:
349362
# this needs to be set so the correct `trainer._active_loop` is picked
350363
self.trainer.validating = True

0 commit comments

Comments
 (0)