Skip to content

Commit 703fa62

Browse files
committed
Ensure increment_completed is called prior to on_batch_end hooks
1 parent 06a8d5b commit 703fa62

File tree

5 files changed

+29
-14
lines changed

5 files changed

+29
-14
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file.
44

55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

7+
## [2.5.0] -
8+
9+
### Added
10+
11+
### Changed
12+
13+
- `batch_progress.increment_completed()` is now called right prior to `on_train_batch_end`, `on_validation_batch_end`, and `on_predict_batch_end` to ensure hooks see up-to-date progress information about the completed batch and avoid skews during restarts ([]())
14+
15+
### Removed
16+
17+
### Fixed
18+
719

820
## [2.4.0] - 2024-08-06
921

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,10 +410,10 @@ def _evaluation_step(
410410
call._call_callback_hooks(trainer, hook_name, output, *hook_kwargs.values())
411411
call._call_lightning_module_hook(trainer, hook_name, output, *hook_kwargs.values())
412412

413-
trainer._logger_connector.on_batch_end()
414-
415413
self.batch_progress.increment_completed()
416414

415+
trainer._logger_connector.on_batch_end()
416+
417417
if not trainer.sanity_checking:
418418
# indicate the loop has run
419419
self._has_run = True

src/lightning/pytorch/loops/prediction_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,11 +263,11 @@ def _predict_step(
263263
dataloader_idx = data_fetcher._dataloader_idx
264264
hook_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx if self.num_dataloaders > 1 else None)
265265

266+
self.batch_progress.increment_completed()
267+
266268
call._call_callback_hooks(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values())
267269
call._call_lightning_module_hook(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values())
268270

269-
self.batch_progress.increment_completed()
270-
271271
if self._return_predictions or any_on_epoch:
272272
self._predictions[dataloader_idx].append(move_data_to_device(predictions, torch.device("cpu")))
273273

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,12 +266,14 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
266266
# update `is_last_batch` again after dataloader_iter was fetched in `training_step()`
267267
self.batch_progress.is_last_batch = data_fetcher.done
268268

269+
# we increment prior to on_batch_end so checkpoints can see progress correctly
270+
# failure to do this will lead to incorrect restarts
271+
self.batch_progress.increment_completed()
272+
269273
call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
270274
call._call_lightning_module_hook(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
271275
trainer._logger_connector.on_batch_end()
272276

273-
self.batch_progress.increment_completed()
274-
275277
# -----------------------------------------
276278
# SAVE METRICS TO LOGGERS AND PROGRESS_BAR
277279
# -----------------------------------------

tests/tests_pytorch/loops/test_loops.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -570,10 +570,11 @@ def test_fit_loop_reset(tmp_path):
570570
assert epoch_loop.restarting
571571
assert epoch_loop.batch_progress.total.ready == 2
572572
assert epoch_loop.batch_progress.total.processed == 2
573-
assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end
574-
assert epoch_loop.batch_progress.current.ready == 1 # currents get set to the completed value
575-
assert epoch_loop.batch_progress.current.processed == 1
576-
assert epoch_loop.batch_progress.current.completed == 1
573+
assert epoch_loop.batch_progress.total.completed == 2 # the checkpoint was saved on train_batch_end
574+
# this used to be 1 but progress is now recorded before train_batch_end
575+
assert epoch_loop.batch_progress.current.ready == 2 # currents get set to the completed value
576+
assert epoch_loop.batch_progress.current.processed == 2
577+
assert epoch_loop.batch_progress.current.completed == 2
577578

578579
assert optimizer_loop.restarting
579580

@@ -600,10 +601,10 @@ def test_fit_loop_reset(tmp_path):
600601
assert epoch_loop.restarting
601602
assert epoch_loop.batch_progress.total.ready == 4
602603
assert epoch_loop.batch_progress.total.processed == 4
603-
assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end
604-
assert epoch_loop.batch_progress.current.ready == 3 # currents get set to the completed value
605-
assert epoch_loop.batch_progress.current.processed == 3
606-
assert epoch_loop.batch_progress.current.completed == 3
604+
assert epoch_loop.batch_progress.total.completed == 4 # the checkpoint was saved on train_batch_end
605+
assert epoch_loop.batch_progress.current.ready == 4 # currents get set to the completed value
606+
assert epoch_loop.batch_progress.current.processed == 4
607+
assert epoch_loop.batch_progress.current.completed == 4
607608

608609

609610
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)