From 703fa629b3fd8982108f461cc891f0195f47cc8c Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 30 Oct 2024 21:48:34 +0100 Subject: [PATCH 1/4] Ensure increment_completed is called prior to on batch_end hooks --- src/lightning/pytorch/CHANGELOG.md | 12 ++++++++++++ src/lightning/pytorch/loops/evaluation_loop.py | 4 ++-- src/lightning/pytorch/loops/prediction_loop.py | 4 ++-- .../pytorch/loops/training_epoch_loop.py | 6 ++++-- tests/tests_pytorch/loops/test_loops.py | 17 +++++++++-------- 5 files changed, 29 insertions(+), 14 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 1b251b8fb06fa..01e3aec7dca78 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [2.5.0] - + +### Added + +### Changed + +- `batch_progress.increment_completed()` is now called right prior to `on_train_batch_end`, `on_validation_batch_end`, and `on_predict_batch_end` to ensure hooks see up-to-date progress information about the completed batch and avoid skews during restarts ([]()) + +### Removed + +### Fixed + ## [2.4.0] - 2024-08-06 diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 0ab3901cf072d..71cf13008a24d 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -410,10 +410,10 @@ def _evaluation_step( call._call_callback_hooks(trainer, hook_name, output, *hook_kwargs.values()) call._call_lightning_module_hook(trainer, hook_name, output, *hook_kwargs.values()) - trainer._logger_connector.on_batch_end() - self.batch_progress.increment_completed() + trainer._logger_connector.on_batch_end() + if not trainer.sanity_checking: # indicate the loop has run self._has_run = True diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py index 9002e6280ffc6..b936b521d0b94 100644 --- a/src/lightning/pytorch/loops/prediction_loop.py +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -263,11 +263,11 @@ def _predict_step( dataloader_idx = data_fetcher._dataloader_idx hook_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx if self.num_dataloaders > 1 else None) + self.batch_progress.increment_completed() + call._call_callback_hooks(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values()) call._call_lightning_module_hook(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values()) - self.batch_progress.increment_completed() - if self._return_predictions or any_on_epoch: self._predictions[dataloader_idx].append(move_data_to_device(predictions, torch.device("cpu"))) diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 9e36ee65176c8..890766897a828 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -266,12 +266,14 @@ def advance(self, data_fetcher: _DataFetcher) -> None: # update `is_last_batch` again after dataloader_iter was fetched in `training_step()` self.batch_progress.is_last_batch = data_fetcher.done + # we increment prior to on_batch_end so checkpoints can see progress correctly + # failure to do this will lead to incorrect restarts + self.batch_progress.increment_completed() + call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx) call._call_lightning_module_hook(trainer, "on_train_batch_end", batch_output, batch, batch_idx) trainer._logger_connector.on_batch_end() - self.batch_progress.increment_completed() - # ----------------------------------------- # SAVE METRICS TO LOGGERS AND PROGRESS_BAR # ----------------------------------------- diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index ff317cd2e18ba..91283f56dd9ba 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -570,10 +570,11 @@ def test_fit_loop_reset(tmp_path): assert epoch_loop.restarting assert epoch_loop.batch_progress.total.ready == 2 assert epoch_loop.batch_progress.total.processed == 2 - assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end - assert epoch_loop.batch_progress.current.ready == 1 # currents get set to the completed value - assert epoch_loop.batch_progress.current.processed == 1 - assert epoch_loop.batch_progress.current.completed == 1 + assert epoch_loop.batch_progress.total.completed == 2 # the checkpoint was saved on train_batch_end + # this used to be 1 but progress is now recorded before train_batch_end + assert epoch_loop.batch_progress.current.ready == 2 # currents get set to the completed value + assert epoch_loop.batch_progress.current.processed == 2 + assert epoch_loop.batch_progress.current.completed == 2 assert optimizer_loop.restarting @@ -600,10 +601,10 @@ def test_fit_loop_reset(tmp_path): assert epoch_loop.restarting assert epoch_loop.batch_progress.total.ready == 4 assert epoch_loop.batch_progress.total.processed == 4 - assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end - assert epoch_loop.batch_progress.current.ready == 3 # currents get set to the completed value - assert epoch_loop.batch_progress.current.processed == 3 - assert epoch_loop.batch_progress.current.completed == 3 + assert epoch_loop.batch_progress.total.completed == 4 # the checkpoint was saved on train_batch_end + assert epoch_loop.batch_progress.current.ready == 4 # currents get set to the completed value + assert epoch_loop.batch_progress.current.processed == 4 + assert epoch_loop.batch_progress.current.completed == 4 @pytest.mark.parametrize( From c2e46406a28c50f3b974c8655206f23f0a72c8e4 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 30 Oct 2024 21:50:14 +0100 Subject: [PATCH 2/4] Remove comment --- src/lightning/pytorch/CHANGELOG.md | 2 +- tests/tests_pytorch/loops/test_loops.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 01e3aec7dca78..4d146a524cdec 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -4,7 +4,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). -## [2.5.0] - +## [2.5.0] - ### Added diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 91283f56dd9ba..9ec6c6e75e0bd 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -571,7 +571,6 @@ def test_fit_loop_reset(tmp_path): assert epoch_loop.batch_progress.total.ready == 2 assert epoch_loop.batch_progress.total.processed == 2 assert epoch_loop.batch_progress.total.completed == 2 # the checkpoint was saved on train_batch_end - # this used to be 1 but progress is now recorded before train_batch_end assert epoch_loop.batch_progress.current.ready == 2 # currents get set to the completed value assert epoch_loop.batch_progress.current.processed == 2 assert epoch_loop.batch_progress.current.completed == 2 From d6324b65a6e3fd72fe708c4a98d44cf6cdbcf0da Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 30 Oct 2024 21:59:34 +0100 Subject: [PATCH 3/4] Reword comment --- src/lightning/pytorch/loops/training_epoch_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 890766897a828..85d6c82af5614 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -267,7 +267,7 @@ def advance(self, data_fetcher: _DataFetcher) -> None: self.batch_progress.is_last_batch = data_fetcher.done # we increment prior to on_batch_end so checkpoints can see progress correctly - # failure to do this will lead to incorrect restarts + # failure to do this will lead to incorrect total batch completed progress counts self.batch_progress.increment_completed() call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx) From 136f9a7a84e74fc4e8d3209220a80e171232c982 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 30 Oct 2024 23:42:08 +0100 Subject: [PATCH 4/4] Make _batches_that_stepped consistent with progress --- src/lightning/pytorch/loops/training_epoch_loop.py | 7 ++++--- tests/tests_pytorch/loops/test_loops.py | 4 ++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 85d6c82af5614..deffac1adf270 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -270,6 +270,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None: # failure to do this will lead to incorrect total batch completed progress counts self.batch_progress.increment_completed() + if not self._should_accumulate(): + # this is increased once per batch disregarding multiple optimizers on purpose for loggers + self._batches_that_stepped += 1 + call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx) call._call_lightning_module_hook(trainer, "on_train_batch_end", batch_output, batch, batch_idx) trainer._logger_connector.on_batch_end() @@ -301,9 +305,6 @@ def on_advance_end(self, data_fetcher: _DataFetcher) -> None: # update plateau LR scheduler after metrics are logged self.update_lr_schedulers("step", update_plateau_schedulers=True) - if not self._should_accumulate(): - # this is increased once per batch disregarding multiple optimizers on purpose for loggers - self._batches_that_stepped += 1 # this will save based on the `batches_that_stepped` value self._save_loggers_on_train_batch_end() diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 9ec6c6e75e0bd..97f4c6ebc7041 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -575,6 +575,8 @@ def test_fit_loop_reset(tmp_path): assert epoch_loop.batch_progress.current.processed == 2 assert epoch_loop.batch_progress.current.completed == 2 + assert epoch_loop._batches_that_stepped == 2 + assert optimizer_loop.restarting # reset state loaded from a checkpoint from the end of an epoch @@ -605,6 +607,8 @@ def test_fit_loop_reset(tmp_path): assert epoch_loop.batch_progress.current.processed == 4 assert epoch_loop.batch_progress.current.completed == 4 + assert epoch_loop._batches_that_stepped == 4 + @pytest.mark.parametrize( ("train_datasets", "val_datasets"),