Skip to content

Commit 7e22306

Browse files
committed
Handle end of batch
1 parent a5885ae commit 7e22306

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,12 @@ def reset(self) -> None:
150150
# when loop state dict is saved on_train_batch_end, that is, before increment completed is called
151151
if self.batch_progress.total.completed < self.batch_progress.total.processed:
152152
self.batch_progress.increment_completed()
153+
# handle situation in which save happened on_train_batch_end and epoch is at end
154+
if self.batch_progress.current.completed >= self.trainer.num_training_batches:
155+
self.batch_progress.reset_on_run()
156+
self.scheduler_progress.reset_on_run()
157+
self.automatic_optimization.optim_progress.reset_on_run()
158+
self.val_loop.batch_progress.total.reset()
153159
if not self._should_accumulate():
154160
self._batches_that_stepped += 1
155161

0 commit comments

Comments
 (0)