Skip to content

Commit e8bd2d7

Browse files
committed
Avoid skipping to val end if saved mid validation
1 parent 7d0b5a1 commit e8bd2d7

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,15 @@ def setup_data(self) -> None:
197197
# this depends on the data used, so reset it too
198198
self._seen_batches_per_dataloader = defaultdict(int)
199199

200+
@property
def restarting_mid_evaluation(self) -> bool:
    """Report whether the loop is restarting with one evaluation batch left unfinished.

    Returns:
        ``True`` only when ``self.restarting`` is truthy and the ``total``
        counters of ``self.batch_progress`` satisfy ``started == ready`` and
        ``processed == completed == started - 1``; ``False`` otherwise.
    """
    if not self.restarting:
        return False

    # Read the tracker once; every comparison below uses the same counters.
    total = self.batch_progress.total
    return (
        # the last batch that was made ready was also started ...
        total.started == total.ready
        # ... but exactly that one batch was never processed ...
        and total.processed == total.started - 1
        # ... and everything that was processed was also completed.
        and total.completed == total.processed
    )
208+
200209
@property
201210
def restarting_on_evaluation_end(self) -> bool:
202211
return (

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,8 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
217217
218218
"""
219219
if self.restarting and self._should_check_val_fx(data_fetcher):
220+
if self.val_loop.restarting_mid_evaluation:
221+
return
220222
# fast forward progress counters to end of validation
221223
self.val_loop.increment_progress_to_evaluation_end()
222224

tests/tests_pytorch/loops/test_loops.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,12 +397,13 @@ def training_step(self, batch, batch_idx):
397397
assert state_dict == checkpoint["loops"]["fit_loop"]
398398

399399
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
400-
# test resetting manually, we expect all `ready` counters to be reset to `completed`
400+
# test resetting manually, we expect the `ready` counter for batch to be reset to `completed`
401+
# but the `ready` counter for epoch to not be reset, since we are still mid epoch
401402
trainer.fit_loop.reset()
402403
trainer.fit_loop.epoch_loop.reset()
403404

404405
epoch_progress = trainer.fit_loop.epoch_progress
405-
assert epoch_progress.current.ready == stop_epoch
406+
assert epoch_progress.current.ready == stop_epoch + 1
406407
assert epoch_progress.current.completed == stop_epoch
407408

408409
batch_progress = trainer.fit_loop.epoch_loop.batch_progress
@@ -418,7 +419,7 @@ def training_step(self, batch, batch_idx):
418419
state_dict = trainer.fit_loop.state_dict()
419420
assert state_dict != checkpoint["loops"]["fit_loop"]
420421
assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1
421-
assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch
422+
assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch + 1
422423

423424

424425
def test_loop_state_on_complete_run(tmp_path):

0 commit comments

Comments
 (0)