
Commit f7692a6

LTMeyer and SkafteNicki authored
LR updates are called at the end of a skipped epoch (#21307)
* fix: Update lr if train_batch_start returns -1
* fix: Rename variable: should_skip_rest_of_epoch
* chore: Update changelog
* fix: Batch increment
* Apply suggestion from @SkafteNicki
* test: Check lr is updated at the end of epoch

  When `on_train_batch_start` returns -1, the rest of the epoch is skipped. The lr update should still happen at the end of the epoch.

  - Test is_last_batch has been set correctly
  - Test lr has been updated at the end of each epoch

* doc: Add documentation for lr update

---------

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
1 parent 8f1c1ac commit f7692a6
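
The scenario this commit fixes can be reproduced with a few lines of user code. The sketch below is not part of the commit; it assumes BoringModel's default optimizer setup (SGD with lr=0.1 and a per-epoch StepLR with gamma=0.1) and shows that, with this fix, the scheduler still steps at the end of an epoch whose remaining batches were skipped:

# Minimal sketch (not from the commit): skip most of each epoch via on_train_batch_start
# and confirm the epoch-interval scheduler still steps at epoch end.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class SkipRestOfEpochModel(BoringModel):
    def on_train_batch_start(self, batch, batch_idx):
        # Returning -1 skips the remaining batches of the current epoch.
        if batch_idx == 3:
            return -1
        return super().on_train_batch_start(batch, batch_idx)


if __name__ == "__main__":
    trainer = Trainer(max_epochs=2, limit_train_batches=10, enable_progress_bar=False)
    trainer.fit(SkipRestOfEpochModel())
    # Assuming BoringModel's SGD(lr=0.1) + StepLR(step_size=1, gamma=0.1),
    # the learning rate should now be close to 0.1 * 0.1**2 after two epochs.
    print([group["lr"] for group in trainer.optimizers[0].param_groups])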

File tree (4 files changed, +55 −18 lines):
  src/lightning/fabric/CHANGELOG.md
  src/lightning/pytorch/core/hooks.py
  src/lightning/pytorch/loops/training_epoch_loop.py
  tests/tests_pytorch/loops/test_training_loop.py


src/lightning/fabric/CHANGELOG.md
Lines changed: 5 additions & 1 deletion

@@ -25,7 +25,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed `EADDRINUSE` errors in distributed tests with port manager and retry logic ([#21309](https://github.com/Lightning-AI/pytorch-lightning/pull/21309))
+
+
+- Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)).
+
 
 
 ---

src/lightning/pytorch/core/hooks.py
Lines changed: 1 addition & 0 deletions

@@ -69,6 +69,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]:
         """Called in the training loop before anything happens for that batch.
 
         If you return -1 here, you will skip training for the rest of the current epoch.
+        Learning rate scheduler will still be stepped at the end of epoch.
 
         Args:
             batch: The batched data as it is returned by the training DataLoader.
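
The docstring now states that epoch-interval LR schedulers still step when the hook skips the rest of the epoch. As a hypothetical usage illustration (not from the commit), a module could use this hook to enforce a per-epoch time budget:

# Hypothetical illustration of the documented hook contract: returning -1 from
# on_train_batch_start skips the remaining batches of the epoch, and epoch-interval
# schedulers are still stepped at the end of that epoch.
import time
from typing import Any, Optional

from lightning.pytorch.demos.boring_classes import BoringModel


class TimeBudgetModel(BoringModel):
    def __init__(self, epoch_budget_s: float = 30.0):
        super().__init__()
        self.epoch_budget_s = epoch_budget_s
        self._epoch_start = 0.0

    def on_train_epoch_start(self) -> None:
        # Record when the epoch started so the budget can be checked per batch.
        self._epoch_start = time.monotonic()

    def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]:
        if time.monotonic() - self._epoch_start > self.epoch_budget_s:
            return -1  # skip the rest of this epoch once the budget is spent
        return super().on_train_batch_start(batch, batch_idx)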

src/lightning/pytorch/loops/training_epoch_loop.py
Lines changed: 24 additions & 17 deletions

@@ -325,30 +325,33 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         trainer._logger_connector.on_batch_start(batch)
 
         batch_output: _BATCH_OUTPUTS_TYPE = None  # for mypy
+        should_skip_rest_of_epoch = False
+
         if batch is None and not using_dataloader_iter:
             self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
         else:
             # hook
             call._call_callback_hooks(trainer, "on_train_batch_start", batch, batch_idx)
             response = call._call_lightning_module_hook(trainer, "on_train_batch_start", batch, batch_idx)
             call._call_strategy_hook(trainer, "on_train_batch_start", batch, batch_idx)
-            if response == -1:
-                self.batch_progress.increment_processed()
-                raise StopIteration
-
-            self.batch_progress.increment_started()
-
-            kwargs = (
-                self._build_kwargs(OrderedDict(), batch, batch_idx)
-                if not using_dataloader_iter
-                else OrderedDict(any=dataloader_iter)
-            )
-            with trainer.profiler.profile("run_training_batch"):
-                if trainer.lightning_module.automatic_optimization:
-                    # in automatic optimization, there can only be one optimizer
-                    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
-                else:
-                    batch_output = self.manual_optimization.run(kwargs)
+            should_skip_rest_of_epoch = response == -1
+            # Signal this is the last batch for the current epoch
+            if should_skip_rest_of_epoch:
+                self.batch_progress.increment_by(0, is_last_batch=True)
+            else:
+                self.batch_progress.increment_started()
+
+                kwargs = (
+                    self._build_kwargs(OrderedDict(), batch, batch_idx)
+                    if not using_dataloader_iter
+                    else OrderedDict(any=dataloader_iter)
+                )
+                with trainer.profiler.profile("run_training_batch"):
+                    if trainer.lightning_module.automatic_optimization:
+                        # in automatic optimization, there can only be one optimizer
+                        batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
+                    else:
+                        batch_output = self.manual_optimization.run(kwargs)
 
         self.batch_progress.increment_processed()
 

@@ -358,6 +361,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         if self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=False)
 
+        if should_skip_rest_of_epoch:
+            # Only raise StopIteration now so that the training epoch loop can finish
+            raise StopIteration
+
         if using_dataloader_iter:
            # update the hook kwargs now that the step method might have consumed the iterator
            batch = data_fetcher._batch
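
The key change above is that the loop no longer raises StopIteration as soon as the hook returns -1; it records a flag, marks the batch as the last one of the epoch, lets the epoch-end scheduler update run, and only then stops. A framework-free sketch of that control-flow pattern (names are illustrative, not Lightning's):

# Illustrative, framework-free sketch of the deferred-StopIteration pattern:
# instead of bailing out as soon as the hook asks to skip, mark the batch as the
# last one, finish the epoch-end bookkeeping (including epoch-interval LR
# stepping), and only then stop iterating.
def run_epoch(batches, on_batch_start, step_batch, step_epoch_schedulers):
    for batch_idx, batch in enumerate(batches):
        skip_rest_of_epoch = on_batch_start(batch, batch_idx) == -1
        if not skip_rest_of_epoch:
            step_batch(batch, batch_idx)

        is_last = skip_rest_of_epoch or batch_idx == len(batches) - 1
        if is_last:
            # epoch-end work happens even when the rest of the epoch was skipped
            step_epoch_schedulers()
        if skip_rest_of_epoch:
            break  # the real loop raises StopIteration at this point


# Example usage with trivial callables: batches 0-2 train, batch 3 skips the rest,
# and the epoch schedulers are still stepped before the loop exits.
if __name__ == "__main__":
    run_epoch(
        batches=list(range(10)),
        on_batch_start=lambda b, i: -1 if i == 3 else None,
        step_batch=lambda b, i: print(f"trained on batch {i}"),
        step_epoch_schedulers=lambda: print("stepped epoch schedulers"),
    )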

tests/tests_pytorch/loops/test_training_loop.py
Lines changed: 25 additions & 0 deletions

@@ -111,6 +111,8 @@ def on_train_batch_start(self, batch, batch_idx):
         assert trainer.fit_loop.batch_idx == batch_idx_
         assert trainer.global_step == batch_idx_ * max_epochs
 
+        assert trainer.is_last_batch
+
 
 def test_should_stop_mid_epoch(tmp_path):
     """Test that training correctly stops mid epoch and that validation is still called at the right time."""

@@ -305,3 +307,26 @@ def test_eval_mode_warning(tmp_path, warn):
         w for w in warning_list if issubclass(w.category, PossibleUserWarning) and "eval mode" in str(w.message)
     ]
     assert len(eval_warnings) == 0, "Expected no eval mode warnings"
+
+
+@pytest.mark.parametrize(("max_epochs", "batch_idx_"), [(2, 5), (3, 8)])
+def test_lr_updated_on_train_batch_start_returns_minus_one(tmp_path, max_epochs, batch_idx_):
+    """Test that when the rest of the epoch is skipped, due to on_train_batch_start returning -1, the learning rate is
+    still updated when it should, at the end of the epoch."""
+
+    class TestModel(BoringModel):
+        def on_train_batch_start(self, batch, batch_idx):
+            if batch_idx == batch_idx_:
+                return -1
+            return super().on_train_batch_start(batch, batch_idx)
+
+    model = TestModel()
+    init_lr = 0.1
+    trainer = Trainer(default_root_dir=tmp_path, limit_train_batches=10, max_epochs=max_epochs)
+    trainer.fit(model)
+
+    adjusted_lr = [pg["lr"] for pg in trainer.optimizers[0].param_groups]
+
+    assert len(trainer.lr_scheduler_configs) == 1
+    assert all(a == adjusted_lr[0] for a in adjusted_lr)
+    assert init_lr * 0.1**max_epochs == adjusted_lr[0]
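
The final assertion relies on BoringModel stepping a StepLR with gamma=0.1 once per epoch from an initial lr of 0.1, so after max_epochs epochs the lr is 0.1 * 0.1**max_epochs. A stand-alone check of that arithmetic with plain PyTorch (an illustration, not part of the test suite):

# Stand-alone check of the LR decay the test expects
# (assumes SGD(lr=0.1) and StepLR(step_size=1, gamma=0.1), as in BoringModel).
import math

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

max_epochs = 2
for _ in range(max_epochs):
    optimizer.step()  # step the optimizer first to avoid the scheduler-order warning
    scheduler.step()  # one epoch-interval scheduler step per epoch

assert math.isclose(optimizer.param_groups[0]["lr"], 0.1 * 0.1**max_epochs)
print(optimizer.param_groups[0]["lr"])  # ~0.001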
