diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 1b251b8fb06fa..4d146a524cdec 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
+## [2.5.0] -
+
+### Added
+
+### Changed
+
+- `batch_progress.increment_completed()` is now called right before `on_train_batch_end`, `on_validation_batch_end`, and `on_predict_batch_end`, so these hooks see up-to-date progress information for the batch that just completed and progress counts no longer skew across restarts ([]())
+
+### Removed
+
+### Fixed
+
 ## [2.4.0] - 2024-08-06
 
diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py
index 0ab3901cf072d..71cf13008a24d 100644
--- a/src/lightning/pytorch/loops/evaluation_loop.py
+++ b/src/lightning/pytorch/loops/evaluation_loop.py
@@ -410,10 +410,10 @@ def _evaluation_step(
         call._call_callback_hooks(trainer, hook_name, output, *hook_kwargs.values())
         call._call_lightning_module_hook(trainer, hook_name, output, *hook_kwargs.values())
 
-        trainer._logger_connector.on_batch_end()
-
         self.batch_progress.increment_completed()
 
+        trainer._logger_connector.on_batch_end()
+
         if not trainer.sanity_checking:
             # indicate the loop has run
             self._has_run = True
diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py
index 9002e6280ffc6..b936b521d0b94 100644
--- a/src/lightning/pytorch/loops/prediction_loop.py
+++ b/src/lightning/pytorch/loops/prediction_loop.py
@@ -263,11 +263,11 @@ def _predict_step(
             dataloader_idx = data_fetcher._dataloader_idx
             hook_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx if self.num_dataloaders > 1 else None)
 
+        self.batch_progress.increment_completed()
+
         call._call_callback_hooks(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values())
         call._call_lightning_module_hook(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values())
 
-        self.batch_progress.increment_completed()
-
         if self._return_predictions or any_on_epoch:
             self._predictions[dataloader_idx].append(move_data_to_device(predictions, torch.device("cpu")))
 
diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py
index 9e36ee65176c8..deffac1adf270 100644
--- a/src/lightning/pytorch/loops/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/training_epoch_loop.py
@@ -266,12 +266,18 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
             # update `is_last_batch` again after dataloader_iter was fetched in `training_step()`
             self.batch_progress.is_last_batch = data_fetcher.done
 
+        # we increment `completed` prior to `on_train_batch_end` so checkpoints saved from that hook see correct progress
+        # failing to do this leads to an incorrect total count of completed batches
+        self.batch_progress.increment_completed()
+
+        if not self._should_accumulate():
+            # this is increased once per batch disregarding multiple optimizers on purpose for loggers
+            self._batches_that_stepped += 1
+
         call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
         call._call_lightning_module_hook(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
         trainer._logger_connector.on_batch_end()
 
-        self.batch_progress.increment_completed()
-
         # -----------------------------------------
         # SAVE METRICS TO LOGGERS AND PROGRESS_BAR
         # -----------------------------------------
@@ -299,9 +305,6 @@ def on_advance_end(self, data_fetcher: _DataFetcher) -> None:
         # update plateau LR scheduler after metrics are logged
         self.update_lr_schedulers("step", update_plateau_schedulers=True)
 
-        if not self._should_accumulate():
-            # this is increased once per batch disregarding multiple optimizers on purpose for loggers
-            self._batches_that_stepped += 1
 
         # this will save based on the `batches_that_stepped` value
         self._save_loggers_on_train_batch_end()
diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py
index ff317cd2e18ba..97f4c6ebc7041 100644
--- a/tests/tests_pytorch/loops/test_loops.py
+++ b/tests/tests_pytorch/loops/test_loops.py
@@ -570,10 +570,12 @@ def test_fit_loop_reset(tmp_path):
     assert epoch_loop.restarting
     assert epoch_loop.batch_progress.total.ready == 2
     assert epoch_loop.batch_progress.total.processed == 2
-    assert epoch_loop.batch_progress.total.completed == 1  # the checkpoint was saved on train_batch_end
-    assert epoch_loop.batch_progress.current.ready == 1  # currents get set to the completed value
-    assert epoch_loop.batch_progress.current.processed == 1
-    assert epoch_loop.batch_progress.current.completed == 1
+    assert epoch_loop.batch_progress.total.completed == 2  # the checkpoint was saved on train_batch_end
+    assert epoch_loop.batch_progress.current.ready == 2  # currents get set to the completed value
+    assert epoch_loop.batch_progress.current.processed == 2
+    assert epoch_loop.batch_progress.current.completed == 2
+
+    assert epoch_loop._batches_that_stepped == 2
 
     assert optimizer_loop.restarting
 
@@ -600,10 +602,12 @@
     assert epoch_loop.restarting
     assert epoch_loop.batch_progress.total.ready == 4
     assert epoch_loop.batch_progress.total.processed == 4
-    assert epoch_loop.batch_progress.total.completed == 3  # the checkpoint was saved on train_batch_end
-    assert epoch_loop.batch_progress.current.ready == 3  # currents get set to the completed value
-    assert epoch_loop.batch_progress.current.processed == 3
-    assert epoch_loop.batch_progress.current.completed == 3
+    assert epoch_loop.batch_progress.total.completed == 4  # the checkpoint was saved on train_batch_end
+    assert epoch_loop.batch_progress.current.ready == 4  # currents get set to the completed value
+    assert epoch_loop.batch_progress.current.processed == 4
+    assert epoch_loop.batch_progress.current.completed == 4
+
+    assert epoch_loop._batches_that_stepped == 4
 
 
 @pytest.mark.parametrize(
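
Illustration (not part of the patch): a minimal sketch, assuming a Lightning build with this change applied, of what a callback now observes inside `on_train_batch_end`. The `ProgressInspector` and `TinyModule` names are hypothetical and exist only for this demo; `trainer.fit_loop.epoch_loop.batch_progress` is the internal progress tracker touched by this diff and is not a public API.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as L


class TinyModule(L.LightningModule):
    """Hypothetical minimal module used only to drive a few training batches."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


class ProgressInspector(L.Callback):
    """Hypothetical callback: reads the internal batch-progress counters in `on_train_batch_end`."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        progress = trainer.fit_loop.epoch_loop.batch_progress
        # With this change, `completed` already includes the batch that just finished,
        # so it matches `processed` here; previously it lagged by one at this point.
        print(
            f"batch_idx={batch_idx} "
            f"processed={progress.total.processed} completed={progress.total.completed}"
        )


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(16, 4), torch.randn(16, 1))
    trainer = L.Trainer(
        max_epochs=1,
        limit_train_batches=4,
        callbacks=[ProgressInspector()],
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(TinyModule(), DataLoader(dataset, batch_size=4))
```

Before this change, `completed` lagged `processed` by one inside `on_train_batch_end`, which is why checkpoints saved from that hook recorded the smaller counts asserted in the old `test_fit_loop_reset` expectations above.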