diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 1db30fb489b47..8de1112b58f7c 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -28,6 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed
 
 - Fixed `EADDRINUSE` errors in distributed tests with port manager and retry logic ([#21309](https://github.com/Lightning-AI/pytorch-lightning/pull/21309))
+- Fixed learning rate scheduler not being stepped at the end of the epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296))
 
 
 ---
diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py
index 0b0ab14244e38..ea95e2c4573a6 100644
--- a/src/lightning/pytorch/core/hooks.py
+++ b/src/lightning/pytorch/core/hooks.py
@@ -69,6 +69,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]:
         """Called in the training loop before anything happens for that batch.
 
         If you return -1 here, you will skip training for the rest of the current epoch.
+        The learning rate scheduler will still be stepped at the end of the epoch.
 
         Args:
             batch: The batched data as it is returned by the training DataLoader.
diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py
index 3d01780b705fe..6212bfe264e6e 100644
--- a/src/lightning/pytorch/loops/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/training_epoch_loop.py
@@ -325,6 +325,8 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         trainer._logger_connector.on_batch_start(batch)
 
         batch_output: _BATCH_OUTPUTS_TYPE = None  # for mypy
+        should_skip_rest_of_epoch = False
+
         if batch is None and not using_dataloader_iter:
             self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
         else:
@@ -332,23 +334,24 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
             call._call_callback_hooks(trainer, "on_train_batch_start", batch, batch_idx)
             response = call._call_lightning_module_hook(trainer, "on_train_batch_start", batch, batch_idx)
             call._call_strategy_hook(trainer, "on_train_batch_start", batch, batch_idx)
-            if response == -1:
-                self.batch_progress.increment_processed()
-                raise StopIteration
-
-            self.batch_progress.increment_started()
-
-            kwargs = (
-                self._build_kwargs(OrderedDict(), batch, batch_idx)
-                if not using_dataloader_iter
-                else OrderedDict(any=dataloader_iter)
-            )
-            with trainer.profiler.profile("run_training_batch"):
-                if trainer.lightning_module.automatic_optimization:
-                    # in automatic optimization, there can only be one optimizer
-                    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
-                else:
-                    batch_output = self.manual_optimization.run(kwargs)
+            should_skip_rest_of_epoch = response == -1
+            # Signal this is the last batch for the current epoch
+            if should_skip_rest_of_epoch:
+                self.batch_progress.increment_by(0, is_last_batch=True)
+            else:
+                self.batch_progress.increment_started()
+
+                kwargs = (
+                    self._build_kwargs(OrderedDict(), batch, batch_idx)
+                    if not using_dataloader_iter
+                    else OrderedDict(any=dataloader_iter)
+                )
+                with trainer.profiler.profile("run_training_batch"):
+                    if trainer.lightning_module.automatic_optimization:
+                        # in automatic optimization, there can only be one optimizer
+                        batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
+                    else:
+                        batch_output = self.manual_optimization.run(kwargs)
 
         self.batch_progress.increment_processed()
 
@@ -358,6 +361,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         if self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=False)
 
+        if should_skip_rest_of_epoch:
+            # Only raise StopIteration now so that the training epoch loop can finish
+            raise StopIteration
+
         if using_dataloader_iter:
             # update the hook kwargs now that the step method might have consumed the iterator
             batch = data_fetcher._batch
diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py
index f5aaa18095fc5..9b6809e3749e3 100644
--- a/tests/tests_pytorch/loops/test_training_loop.py
+++ b/tests/tests_pytorch/loops/test_training_loop.py
@@ -111,6 +111,8 @@ def on_train_batch_start(self, batch, batch_idx):
         assert trainer.fit_loop.batch_idx == batch_idx_
         assert trainer.global_step == batch_idx_ * max_epochs
 
+    assert trainer.is_last_batch
+
 
 def test_should_stop_mid_epoch(tmp_path):
     """Test that training correctly stops mid epoch and that validation is still called at the right time."""
@@ -305,3 +307,26 @@ def test_eval_mode_warning(tmp_path, warn):
         w for w in warning_list if issubclass(w.category, PossibleUserWarning) and "eval mode" in str(w.message)
     ]
     assert len(eval_warnings) == 0, "Expected no eval mode warnings"
+
+
+@pytest.mark.parametrize(("max_epochs", "batch_idx_"), [(2, 5), (3, 8)])
+def test_lr_updated_on_train_batch_start_returns_minus_one(tmp_path, max_epochs, batch_idx_):
+    """Test that when the rest of the epoch is skipped because on_train_batch_start returns -1, the learning rate is
+    still updated at the end of the epoch."""
+
+    class TestModel(BoringModel):
+        def on_train_batch_start(self, batch, batch_idx):
+            if batch_idx == batch_idx_:
+                return -1
+            return super().on_train_batch_start(batch, batch_idx)
+
+    model = TestModel()
+    init_lr = 0.1
+    trainer = Trainer(default_root_dir=tmp_path, limit_train_batches=10, max_epochs=max_epochs)
+    trainer.fit(model)
+
+    adjusted_lr = [pg["lr"] for pg in trainer.optimizers[0].param_groups]
+
+    assert len(trainer.lr_scheduler_configs) == 1
+    assert all(a == adjusted_lr[0] for a in adjusted_lr)
+    assert init_lr * 0.1**max_epochs == adjusted_lr[0]
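
For reference, the behavior this diff establishes can be exercised end to end with the minimal sketch below. It is not part of the patch: `SkipRestOfEpochModel` and `RandomDataset` are illustrative names, and the optimizer/scheduler simply mirror the BoringModel-style defaults used in the test (SGD with `lr=0.1` and an epoch-interval `StepLR(step_size=1, gamma=0.1)`). Returning -1 from `on_train_batch_start` skips the remaining batches of the current epoch; with this change the epoch-interval scheduler is still stepped, so the learning rate decays once per epoch instead of staying frozen.

```python
# Minimal sketch of the fixed behavior (illustrative, not part of the patch).
import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer


class RandomDataset(Dataset):
    """Tiny synthetic dataset so the example is self-contained."""

    def __init__(self, size: int = 10, dim: int = 32) -> None:
        self.data = torch.randn(size, dim)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.data[idx]


class SkipRestOfEpochModel(LightningModule):
    """Skips the remainder of every epoch once `skip_from_batch_idx` is reached."""

    def __init__(self, skip_from_batch_idx: int = 5) -> None:
        super().__init__()
        self.skip_from_batch_idx = skip_from_batch_idx
        self.layer = torch.nn.Linear(32, 2)

    def on_train_batch_start(self, batch, batch_idx):
        # Returning -1 skips training for the rest of the current epoch.
        # With this patch, the epoch-interval LR scheduler is still stepped.
        if batch_idx >= self.skip_from_batch_idx:
            return -1
        return super().on_train_batch_start(batch, batch_idx)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        # Returned in the list form, the scheduler defaults to interval="epoch", frequency=1.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
        return [optimizer], [scheduler]


if __name__ == "__main__":
    trainer = Trainer(max_epochs=2, limit_train_batches=10, logger=False, enable_checkpointing=False)
    trainer.fit(SkipRestOfEpochModel(), DataLoader(RandomDataset(), batch_size=1))
    # Before this change the epoch-level scheduler was not stepped once -1 cut the
    # epoch short, so the lr stayed at 0.1; with the fix it steps once per epoch:
    # 0.1 -> 0.01 -> 0.001.
    print([pg["lr"] for pg in trainer.optimizers[0].param_groups])
```

The key design point of the patch is visible in the sketch's outcome: instead of raising `StopIteration` immediately when the hook returns -1, `advance()` now records the response, marks the batch progress as the last batch, and only raises after `update_lr_schedulers("epoch", ...)` has had a chance to run, which is why the decayed learning rate is observed at the end of training.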