From adb878079d83b2311e6e02f1afc3158b62cfe3e7 Mon Sep 17 00:00:00 2001 From: Lucas Date: Mon, 20 Oct 2025 10:41:32 +0200 Subject: [PATCH 1/7] fix: Update lr if train_batch_start returns -1 --- .../pytorch/loops/training_epoch_loop.py | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 3d01780b705fe..7b6283f437845 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -325,6 +325,8 @@ def advance(self, data_fetcher: _DataFetcher) -> None: trainer._logger_connector.on_batch_start(batch) batch_output: _BATCH_OUTPUTS_TYPE = None # for mypy + should_skip_training = False + if batch is None and not using_dataloader_iter: self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") else: @@ -332,23 +334,24 @@ def advance(self, data_fetcher: _DataFetcher) -> None: call._call_callback_hooks(trainer, "on_train_batch_start", batch, batch_idx) response = call._call_lightning_module_hook(trainer, "on_train_batch_start", batch, batch_idx) call._call_strategy_hook(trainer, "on_train_batch_start", batch, batch_idx) - if response == -1: - self.batch_progress.increment_processed() - raise StopIteration + should_skip_training = response == -1 + # Signal this is the last batch for the current epoch + self.batch_progress.increment_by(0, is_last_batch=True) - self.batch_progress.increment_started() + if not should_skip_training: + self.batch_progress.increment_started() - kwargs = ( - self._build_kwargs(OrderedDict(), batch, batch_idx) - if not using_dataloader_iter - else OrderedDict(any=dataloader_iter) - ) - with trainer.profiler.profile("run_training_batch"): - if trainer.lightning_module.automatic_optimization: - # in automatic optimization, there can only be one optimizer - batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs) - else: - batch_output = self.manual_optimization.run(kwargs) + kwargs = ( + self._build_kwargs(OrderedDict(), batch, batch_idx) + if not using_dataloader_iter + else OrderedDict(any=dataloader_iter) + ) + with trainer.profiler.profile("run_training_batch"): + if trainer.lightning_module.automatic_optimization: + # in automatic optimization, there can only be one optimizer + batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs) + else: + batch_output = self.manual_optimization.run(kwargs) self.batch_progress.increment_processed() @@ -358,6 +361,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None: if self._num_ready_batches_reached(): self.update_lr_schedulers("epoch", update_plateau_schedulers=False) + if should_skip_training: + # Only raise StopIteration now so that the training epoch loop can finish + raise StopIteration + if using_dataloader_iter: # update the hook kwargs now that the step method might have consumed the iterator batch = data_fetcher._batch From 4fa5c05d72dfae387da50bbbd19f55154a94ec07 Mon Sep 17 00:00:00 2001 From: Lucas Date: Mon, 20 Oct 2025 10:43:27 +0200 Subject: [PATCH 2/7] fix: Rename variable: should_skip_rest_of_epoch --- src/lightning/pytorch/loops/training_epoch_loop.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 7b6283f437845..90eb4e46859e9 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -325,7 +325,7 @@ def advance(self, data_fetcher: _DataFetcher) -> None: trainer._logger_connector.on_batch_start(batch) batch_output: _BATCH_OUTPUTS_TYPE = None # for mypy - should_skip_training = False + should_skip_rest_of_epoch = False if batch is None and not using_dataloader_iter: self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") @@ -334,11 +334,11 @@ def advance(self, data_fetcher: _DataFetcher) -> None: call._call_callback_hooks(trainer, "on_train_batch_start", batch, batch_idx) response = call._call_lightning_module_hook(trainer, "on_train_batch_start", batch, batch_idx) call._call_strategy_hook(trainer, "on_train_batch_start", batch, batch_idx) - should_skip_training = response == -1 + should_skip_rest_of_epoch = response == -1 # Signal this is the last batch for the current epoch self.batch_progress.increment_by(0, is_last_batch=True) - if not should_skip_training: + if not should_skip_rest_of_epoch: self.batch_progress.increment_started() kwargs = ( @@ -361,7 +361,7 @@ def advance(self, data_fetcher: _DataFetcher) -> None: if self._num_ready_batches_reached(): self.update_lr_schedulers("epoch", update_plateau_schedulers=False) - if should_skip_training: + if should_skip_rest_of_epoch: # Only raise StopIteration now so that the training epoch loop can finish raise StopIteration From 4a713e559fd7a03edbdef86e2200ba1a441248c6 Mon Sep 17 00:00:00 2001 From: Lucas Date: Mon, 20 Oct 2025 12:23:22 +0200 Subject: [PATCH 3/7] chore: Update changelog --- src/lightning/fabric/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 0e1cc944a3492..94fc998f1684a 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -27,7 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- +- Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([[#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)]). --- From 71d2ebd13b4e3332381a9bd9ee558ebcfe1fe09a Mon Sep 17 00:00:00 2001 From: Lucas Date: Mon, 20 Oct 2025 14:13:52 +0200 Subject: [PATCH 4/7] fix: Batch increment --- src/lightning/pytorch/loops/training_epoch_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 90eb4e46859e9..6212bfe264e6e 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -336,9 +336,9 @@ def advance(self, data_fetcher: _DataFetcher) -> None: call._call_strategy_hook(trainer, "on_train_batch_start", batch, batch_idx) should_skip_rest_of_epoch = response == -1 # Signal this is the last batch for the current epoch - self.batch_progress.increment_by(0, is_last_batch=True) - - if not should_skip_rest_of_epoch: + if should_skip_rest_of_epoch: + self.batch_progress.increment_by(0, is_last_batch=True) + else: self.batch_progress.increment_started() kwargs = ( From 60a7cd3ef457974a06d420251e2f01775f811bd8 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 23 Oct 2025 06:39:09 +0200 Subject: [PATCH 5/7] Apply suggestion from @SkafteNicki --- src/lightning/fabric/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 94fc998f1684a..7233d3c93f688 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -27,7 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([[#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)]). +- Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)). --- From 10d18d91c4c750a180f7be4b10b566e2e2ec097c Mon Sep 17 00:00:00 2001 From: Lucas Date: Thu, 23 Oct 2025 13:52:36 +0200 Subject: [PATCH 6/7] test: Check lr is updated at the end of epoch When `on_train_batch_start` returns -1, the rest of the epoch is skipped. The lr update should still happen at the end of the epoch. - Test is_last_batch has been set correctly - Test lr has been updated at the end of each epoch --- .../tests_pytorch/loops/test_training_loop.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py index f5aaa18095fc5..9b6809e3749e3 100644 --- a/tests/tests_pytorch/loops/test_training_loop.py +++ b/tests/tests_pytorch/loops/test_training_loop.py @@ -111,6 +111,8 @@ def on_train_batch_start(self, batch, batch_idx): assert trainer.fit_loop.batch_idx == batch_idx_ assert trainer.global_step == batch_idx_ * max_epochs + assert trainer.is_last_batch + def test_should_stop_mid_epoch(tmp_path): """Test that training correctly stops mid epoch and that validation is still called at the right time.""" @@ -305,3 +307,26 @@ def test_eval_mode_warning(tmp_path, warn): w for w in warning_list if issubclass(w.category, PossibleUserWarning) and "eval mode" in str(w.message) ] assert len(eval_warnings) == 0, "Expected no eval mode warnings" + + +@pytest.mark.parametrize(("max_epochs", "batch_idx_"), [(2, 5), (3, 8)]) +def test_lr_updated_on_train_batch_start_returns_minus_one(tmp_path, max_epochs, batch_idx_): + """Test that when the rest of the epoch is skipped, due to on_train_batch_start returning -1, the learning rate is + still updated when it should, at the end of the epoch.""" + + class TestModel(BoringModel): + def on_train_batch_start(self, batch, batch_idx): + if batch_idx == batch_idx_: + return -1 + return super().on_train_batch_start(batch, batch_idx) + + model = TestModel() + init_lr = 0.1 + trainer = Trainer(default_root_dir=tmp_path, limit_train_batches=10, max_epochs=max_epochs) + trainer.fit(model) + + adjusted_lr = [pg["lr"] for pg in trainer.optimizers[0].param_groups] + + assert len(trainer.lr_scheduler_configs) == 1 + assert all(a == adjusted_lr[0] for a in adjusted_lr) + assert init_lr * 0.1**max_epochs == adjusted_lr[0] From e115b842df3caa5945f9bf2d7a77da9c227b0d2f Mon Sep 17 00:00:00 2001 From: Lucas Date: Mon, 27 Oct 2025 09:29:29 +0100 Subject: [PATCH 7/7] doc: Add documentation for lr update --- src/lightning/pytorch/core/hooks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py index 0b0ab14244e38..ea95e2c4573a6 100644 --- a/src/lightning/pytorch/core/hooks.py +++ b/src/lightning/pytorch/core/hooks.py @@ -69,6 +69,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]: """Called in the training loop before anything happens for that batch. If you return -1 here, you will skip training for the rest of the current epoch. + Learning rate scheduler will still be stepped at the end of epoch. Args: batch: The batched data as it is returned by the training DataLoader.