@@ -90,6 +90,7 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s
9090 self ._warning_cache = WarningCache ()
9191 self ._batches_that_stepped : int = 0
9292 self ._restart_stage = RestartStage .NONE
93+ self ._skip_next_val = False
9394
9495 @property
9596 def total_batch_idx (self ) -> int :
@@ -257,8 +258,15 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
257258
258259 """
259260 if self .restarting and self ._should_check_val_fx (data_fetcher ):
260- if self .val_loop .restarted_mid_evaluation or self .restarted_on_last :
261+ if self .val_loop .restarted_mid_evaluation :
262+ # Go back and finish running validation
261263 return
264+
265+ if self .restarted_on_last :
266+ # Avoid running validation again if we saved on last
267+ self ._skip_next_val = True
268+ return
269+
262270 # fast forward progress counters to end of validation
263271 self .val_loop .increment_progress_to_evaluation_end ()
264272
@@ -345,6 +353,11 @@ def on_advance_end(self, data_fetcher: _DataFetcher) -> None:
345353 # VALIDATE IF NEEDED
346354 # -----------------------------------------
347355 should_check_val = self ._should_check_val_fx (data_fetcher )
356+
357+ if self ._skip_next_val :
358+ should_check_val = False
359+ self ._skip_next_val = False
360+
348361 if should_check_val :
349362 # this needs to be set so the correct `trainer._active_loop` is picked
350363 self .trainer .validating = True
0 commit comments