Skip to content

Commit 136f9a7

Browse files
committed
Make _batches_that_stepped consistent with progress
1 parent d6324b6 commit 136f9a7

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
270270
# failure to do this will lead to incorrect total batch completed progress counts
271271
self.batch_progress.increment_completed()
272272

273+
if not self._should_accumulate():
274+
# this is increased once per batch disregarding multiple optimizers on purpose for loggers
275+
self._batches_that_stepped += 1
276+
273277
call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
274278
call._call_lightning_module_hook(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
275279
trainer._logger_connector.on_batch_end()
@@ -301,9 +305,6 @@ def on_advance_end(self, data_fetcher: _DataFetcher) -> None:
301305
# update plateau LR scheduler after metrics are logged
302306
self.update_lr_schedulers("step", update_plateau_schedulers=True)
303307

304-
if not self._should_accumulate():
305-
# this is increased once per batch disregarding multiple optimizers on purpose for loggers
306-
self._batches_that_stepped += 1
307308
# this will save based on the `batches_that_stepped` value
308309
self._save_loggers_on_train_batch_end()
309310

tests/tests_pytorch/loops/test_loops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,8 @@ def test_fit_loop_reset(tmp_path):
575575
assert epoch_loop.batch_progress.current.processed == 2
576576
assert epoch_loop.batch_progress.current.completed == 2
577577

578+
assert epoch_loop._batches_that_stepped == 2
579+
578580
assert optimizer_loop.restarting
579581

580582
# reset state loaded from a checkpoint from the end of an epoch
@@ -605,6 +607,8 @@ def test_fit_loop_reset(tmp_path):
605607
assert epoch_loop.batch_progress.current.processed == 4
606608
assert epoch_loop.batch_progress.current.completed == 4
607609

610+
assert epoch_loop._batches_that_stepped == 4
611+
608612

609613
@pytest.mark.parametrize(
610614
("train_datasets", "val_datasets"),

0 commit comments

Comments
 (0)