Commit 524a3e7

littlebullGit authored and Borda committed
Fix LR not being correctly set after using LearningRateFinder callback (#21068)
* fix(tuner/lr_finder): apply LR suggestion after checkpoint restore when used as callback

  Previously, LearningRateFinder applied the suggested LR before restoring the checkpoint, so the optimizer LR was reverted by the restore step. This caused the callback to print "Learning rate set to …" without persisting the change.

  Change:
  - Move LR application to after the checkpoint restore and update both the LightningModule attribute and the active optimizer param groups so the LR persists for training.

  Tests:
  - Add unit test `test_lr_finder_callback_applies_lr_after_restore` to assert the optimizer LR matches the LR Finder suggestion after the search completes.

* changelog

* Apply suggestions from code review

---------

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

(cherry picked from commit 3ed9d4e)
1 parent 39d1a9e commit 524a3e7
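
To illustrate the behaviour this commit fixes, here is a minimal sketch of the callback usage path, loosely adapted from the test added in this commit. `ToyModel`, the random data, and the trainer settings are illustrative placeholders, not code from this commit; with the fix, the suggested LR is applied to the live optimizer instead of only being logged.

```python
# Sketch only: LearningRateFinder used as a callback -- the code path this commit fixes.
# ToyModel, the random data, and the trainer settings are illustrative placeholders.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl
from lightning.pytorch.callbacks import LearningRateFinder


class ToyModel(pl.LightningModule):
    def __init__(self, lr: float = 1e-5):
        super().__init__()
        self.save_hyperparameters()  # exposes self.hparams.lr for the finder to update
        self.layer = nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


# ~100 batches so the finder can sample enough points to produce a suggestion
train_loader = DataLoader(TensorDataset(torch.randn(3200, 32), torch.randn(3200, 1)), batch_size=32)
model = ToyModel()

lr_finder_cb = LearningRateFinder()  # update_attr=True by default, per the test added below
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[lr_finder_cb],
    logger=False,
    enable_progress_bar=False,
    enable_model_summary=False,
)
trainer.fit(model, train_loader)

# Before this fix, the checkpoint restore reverted the optimizer LR; now the suggestion sticks.
suggestion = lr_finder_cb.optimal_lr.suggestion()
if suggestion is not None:  # the finder may fail to produce a suggestion on degenerate loss curves
    assert trainer.optimizers[0].param_groups[0]["lr"] == suggestion
```

The new test in tests/tests_pytorch/tuner/test_lr_finder.py asserts essentially this post-fit condition (with `pytest.approx`).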

File tree

3 files changed: +116 −13 lines changed

- src/lightning/pytorch/CHANGELOG.md
- src/lightning/pytorch/tuner/lr_finder.py
- tests/tests_pytorch/tuner/test_lr_finder.py

src/lightning/pytorch/CHANGELOG.md

Lines changed: 29 additions & 4 deletions
@@ -6,7 +6,33 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ---

-## [2.5.3] - 2025-08-DD
+## [unreleased] - YYYY-MM-DD
+
+### Added
+
+-
+
+
+### Changed
+
+-
+
+
+### Removed
+
+-
+
+
+### Fixed
+
+-
+
+
+- Fixed learning rate not being correctly set after using `LearningRateFinder` callback ([#21068](https://github.com/Lightning-AI/pytorch-lightning/pull/21068))
+
+
+---
+
+## [2.5.3] - 2025-08-13

 ### Changed

@@ -57,14 +83,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Allow LightningCLI to use a customized argument parser class ([#20596](https://github.com/Lightning-AI/pytorch-lightning/pull/20596))
 - Change `wandb` default x-axis to `tensorboard`'s `global_step` when `sync_tensorboard=True` ([#20611](https://github.com/Lightning-AI/pytorch-lightning/pull/20611))
-- Added a new `checkpoint_path_prefix` parameter to the MLflow logger which can control the path to where the MLflow artifacts for the model checkpoints are stored ([#20538](https://github.com/Lightning-AI/pytorch-lightning/pull/20538))
 - CometML logger was updated to support the recent Comet SDK ([#20275](https://github.com/Lightning-AI/pytorch-lightning/pull/20275))
 - bump: testing with latest `torch` 2.6 ([#20509](https://github.com/Lightning-AI/pytorch-lightning/pull/20509))

 ### Fixed

-- Fixed `CSVLogger` logging hyperparameter at every write which increase latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
-- Fixed `OverflowError` when resuming from checkpoint with an iterable dataset ([#20565](https://github.com/Lightning-AI/pytorch-lightning/issues/20565))
+- Fixed CSVLogger logging hyperparameter at every write which increase latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
+- Fixed OverflowError when resuming from checkpoint with an iterable dataset ([#20565](https://github.com/Lightning-AI/pytorch-lightning/issues/20565))
 - Fixed swapped _R_co and _P to prevent type error ([#20508](https://github.com/Lightning-AI/pytorch-lightning/issues/20508))
 - Always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610))
 - Fixed TBPTT example ([#20528](https://github.com/Lightning-AI/pytorch-lightning/pull/20528))

src/lightning/pytorch/tuner/lr_finder.py

Lines changed: 15 additions & 9 deletions
@@ -276,24 +276,30 @@ def _lr_find(
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.enable()

-    # Update lr attr if required
+    # Update results across ranks
     lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
-    if update_attr:
-        lr = lr_finder.suggestion()
-
-        # TODO: log lr.results to self.logger
-        if lr is not None:
-            lightning_setattr(model, attr_name, lr)
-            log.info(f"Learning rate set to {lr}")

-    # Restore initial state of model
+    # Restore initial state of model (this will also restore the original optimizer state)
     trainer._checkpoint_connector.restore(ckpt_path)
     trainer.strategy.remove_checkpoint(ckpt_path)
     trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
     trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
     trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
     trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
     trainer.fit_loop.setup_data()
+
+    # Apply LR suggestion after restoring so it persists for the real training run
+    # When used as a callback, the suggestion would otherwise be lost due to checkpoint restore
+    if update_attr:
+        lr = lr_finder.suggestion()
+        if lr is not None:
+            # update the attribute on the LightningModule (e.g., lr or learning_rate)
+            lightning_setattr(model, attr_name, lr)
+            # also update the currently active optimizer(s) so training continues with the suggested LR
+            for opt in trainer.optimizers or []:
+                for pg in opt.param_groups:
+                    pg["lr"] = lr
+            log.info(f"Learning rate set to {lr}")
     return lr_finder

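
The public `Tuner.lr_find` entry point goes through this same `_lr_find` routine, so the reordering also covers the tuner path. A minimal sketch of that path, reusing the `ToyModel` and `train_loader` placeholders from the earlier sketch (both are assumptions, not code from this commit):

```python
# Sketch only: the Tuner entry point, which wraps the _lr_find routine changed in the diff above.
# ToyModel and train_loader are the toy placeholders defined in the earlier sketch.
import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner

model = ToyModel()
trainer = pl.Trainer(max_epochs=1, logger=False, enable_progress_bar=False, enable_model_summary=False)

# update_attr=True by default: the suggestion is written back to model.hparams.lr
lr_finder = Tuner(trainer).lr_find(model, train_dataloaders=train_loader)

if lr_finder is not None and lr_finder.suggestion() is not None:
    assert model.hparams.lr == lr_finder.suggestion()

# The real fit then builds its optimizer from the updated attribute in configure_optimizers().
trainer.fit(model, train_dataloaders=train_loader)
```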

tests/tests_pytorch/tuner/test_lr_finder.py

Lines changed: 72 additions & 0 deletions
@@ -619,6 +619,78 @@ def test_gradient_correctness():
     assert abs(suggestion - math.pi) < 1e-2, "Suggestion should be close to pi for this synthetic example"


+def test_lr_finder_callback_applies_lr_after_restore(tmp_path):
+    """LearningRateFinder used as a callback should apply its suggested LR to the optimizer used after state
+    restoration."""
+
+    import torch.nn as nn
+    import torch.nn.functional as F
+    from torch.utils.data import DataLoader, Dataset
+
+    from lightning.pytorch.callbacks import LearningRateMonitor
+
+    class RandomDataset(Dataset):
+        def __init__(self, n: int = 256, in_dim: int = 28 * 28):
+            self.x = torch.randn(n, in_dim)
+            self.y = torch.randn(n, in_dim)
+
+        def __len__(self) -> int:
+            return len(self.x)
+
+        def __getitem__(self, idx):
+            return self.x[idx], self.y[idx]
+
+    class TinyAE(BoringModel):
+        def __init__(self, lr: float = 1e-5):
+            super().__init__()
+            self.save_hyperparameters()
+            self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
+            self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
+
+        def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
+            x, y = batch
+            z = self.encoder(x)
+            x_hat = self.decoder(z)
+            loss = F.mse_loss(x_hat, y)
+            return loss
+
+        def configure_optimizers(self):
+            return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
+
+    seed_everything(123)
+
+    ds = RandomDataset(n=512)
+    train_loader = DataLoader(ds, batch_size=64, shuffle=False)
+
+    model = TinyAE(lr=1e-5)
+
+    lr_finder_cb = LearningRateFinder()  # default update_attr=True should apply suggestion
+    lr_monitor = LearningRateMonitor(logging_interval="step")
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=2,
+        callbacks=[lr_finder_cb, lr_monitor],
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        log_every_n_steps=1,
+    )
+
+    trainer.fit(model, train_loader)
+    assert model.hparams.lr is not None
+    # Ensure LR Finder produced a suggestion for this setup; if not, the test can't assert application
+    assert lr_finder_cb.optimal_lr is not None, "LR Finder should have computed results"
+    suggestion = lr_finder_cb.optimal_lr.suggestion()
+    assert suggestion is not None, "LR Finder should produce a suggestion for this setup"
+
+    # Verify that the optimizer used for subsequent training has the suggested LR applied
+    assert trainer.optimizers, "Trainer should have an optimizer after fit"
+    current_lr = trainer.optimizers[0].param_groups[0]["lr"]
+    assert current_lr == pytest.approx(suggestion), (
+        f"LR Finder suggestion {suggestion} should be applied to optimizer, but got {current_lr}"
+    )
+
+
 def test_exponential_vs_linear_mode_gradient_difference(tmp_path):
     """Test that exponential and linear modes produce different but valid suggestions.
