diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 64e756801008a..8ea1402330094 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 -

+- Fixed learning rate not being correctly set after using `LearningRateFinder` callback ([#21068](https://github.com/Lightning-AI/pytorch-lightning/pull/21068))
+
 ---

 ## [2.5.3] - 2025-08-13
diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py
index a5d758f7fff19..b4b61d5cf0f93 100644
--- a/src/lightning/pytorch/tuner/lr_finder.py
+++ b/src/lightning/pytorch/tuner/lr_finder.py
@@ -276,17 +276,10 @@ def _lr_find(
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.enable()

-    # Update lr attr if required
+    # Update results across ranks
     lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
-    if update_attr:
-        lr = lr_finder.suggestion()
-
-        # TODO: log lr.results to self.logger
-        if lr is not None:
-            lightning_setattr(model, attr_name, lr)
-            log.info(f"Learning rate set to {lr}")

-    # Restore initial state of model
+    # Restore initial state of model (this will also restore the original optimizer state)
     trainer._checkpoint_connector.restore(ckpt_path)
     trainer.strategy.remove_checkpoint(ckpt_path)
     trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
@@ -294,6 +287,19 @@
     trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
     trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
     trainer.fit_loop.setup_data()
+
+    # Apply LR suggestion after restoring so it persists for the real training run
+    # When used as a callback, the suggestion would otherwise be lost due to checkpoint restore
+    if update_attr:
+        lr = lr_finder.suggestion()
+        if lr is not None:
+            # update the attribute on the LightningModule (e.g., lr or learning_rate)
+            lightning_setattr(model, attr_name, lr)
+            # also update the currently active optimizer(s) so training continues with the suggested LR
+            for opt in trainer.optimizers or []:
+                for pg in opt.param_groups:
+                    pg["lr"] = lr
+            log.info(f"Learning rate set to {lr}")

     return lr_finder

diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index e2d1b6bd4ee84..69575a351b0a5 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -619,6 +619,78 @@ def test_gradient_correctness():
     assert abs(suggestion - math.pi) < 1e-2, "Suggestion should be close to pi for this synthetic example"


+def test_lr_finder_callback_applies_lr_after_restore(tmp_path):
+    """LearningRateFinder used as a callback should apply its suggested LR to the optimizer used after state
+    restoration."""
+
+    import torch.nn as nn
+    import torch.nn.functional as F
+    from torch.utils.data import DataLoader, Dataset
+
+    from lightning.pytorch.callbacks import LearningRateMonitor
+
+    class RandomDataset(Dataset):
+        def __init__(self, n: int = 256, in_dim: int = 28 * 28):
+            self.x = torch.randn(n, in_dim)
+            self.y = torch.randn(n, in_dim)
+
+        def __len__(self) -> int:
+            return len(self.x)
+
+        def __getitem__(self, idx):
+            return self.x[idx], self.y[idx]
+
+    class TinyAE(BoringModel):
+        def __init__(self, lr: float = 1e-5):
+            super().__init__()
+            self.save_hyperparameters()
+            self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
+            self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
+
+        def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
+            x, y = batch
+            z = self.encoder(x)
+            x_hat = self.decoder(z)
+            loss = F.mse_loss(x_hat, y)
+            return loss
+
+        def configure_optimizers(self):
+            return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
+
+    seed_everything(123)
+
+    ds = RandomDataset(n=512)
+    train_loader = DataLoader(ds, batch_size=64, shuffle=False)
+
+    model = TinyAE(lr=1e-5)
+
+    lr_finder_cb = LearningRateFinder()  # default update_attr=True should apply suggestion
+    lr_monitor = LearningRateMonitor(logging_interval="step")
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=2,
+        callbacks=[lr_finder_cb, lr_monitor],
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        log_every_n_steps=1,
+    )
+
+    trainer.fit(model, train_loader)
+    assert model.hparams.lr is not None
+    # Ensure LR Finder produced a suggestion for this setup; if not, the test can't assert application
+    assert lr_finder_cb.optimal_lr is not None, "LR Finder should have computed results"
+    suggestion = lr_finder_cb.optimal_lr.suggestion()
+    assert suggestion is not None, "LR Finder should produce a suggestion for this setup"
+
+    # Verify that the optimizer used for subsequent training has the suggested LR applied
+    assert trainer.optimizers, "Trainer should have an optimizer after fit"
+    current_lr = trainer.optimizers[0].param_groups[0]["lr"]
+    assert current_lr == pytest.approx(suggestion), (
+        f"LR Finder suggestion {suggestion} should be applied to optimizer, but got {current_lr}"
+    )
+
+
 def test_exponential_vs_linear_mode_gradient_difference(tmp_path):
     """Test that exponential and linear modes produce different but valid suggestions.