24 changes: 15 additions & 9 deletions src/lightning/pytorch/tuner/lr_finder.py
@@ -276,24 +276,30 @@ def _lr_find(
if trainer.progress_bar_callback:
trainer.progress_bar_callback.enable()

# Update lr attr if required
# Update results across ranks
lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
if update_attr:
lr = lr_finder.suggestion()

# TODO: log lr.results to self.logger
if lr is not None:
lightning_setattr(model, attr_name, lr)
log.info(f"Learning rate set to {lr}")

# Restore initial state of model
# Restore initial state of model (this will also restore the original optimizer state)
trainer._checkpoint_connector.restore(ckpt_path)
trainer.strategy.remove_checkpoint(ckpt_path)
trainer.fit_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True
trainer.fit_loop.epoch_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True
trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
trainer.fit_loop._combined_loader = None # reset data fetcher to avoid issues with the next fit
trainer.fit_loop.setup_data()

# Apply LR suggestion after restoring so it persists for the real training run
# When used as a callback, the suggestion would otherwise be lost due to checkpoint restore
if update_attr:
lr = lr_finder.suggestion()
if lr is not None:
# update the attribute on the LightningModule (e.g., lr or learning_rate)
lightning_setattr(model, attr_name, lr)
# also update the currently active optimizer(s) so training continues with the suggested LR
for opt in trainer.optimizers or []:
for pg in opt.param_groups:
pg["lr"] = lr
log.info(f"Learning rate set to {lr}")
return lr_finder


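A minimal usage sketch of the behaviour this hunk changes (module, dataset, and variable names here are illustrative, not taken from the PR): when LearningRateFinder runs as a callback during fit, the suggestion should now survive the checkpoint restore and reach both the module attribute and the live optimizer.

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import LearningRateFinder

class DemoModule(LightningModule):  # hypothetical module, for illustration only
    def __init__(self, lr: float = 1e-5):
        super().__init__()
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

loader = DataLoader(TensorDataset(torch.randn(256, 32), torch.randn(256, 1)), batch_size=32)
model = DemoModule()
trainer = Trainer(max_epochs=1, callbacks=[LearningRateFinder()], enable_progress_bar=False)
trainer.fit(model, loader)

# With this change, the module hparam and the active optimizer should agree on the suggested LR.
print(model.hparams.lr, trainer.optimizers[0].param_groups[0]["lr"])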
76 changes: 76 additions & 0 deletions tests/tests_pytorch/tuner/test_lr_finder.py
@@ -619,6 +619,82 @@ def test_gradient_correctness():
assert abs(suggestion - math.pi) < 1e-2, "Suggestion should be close to pi for this synthetic example"


def test_lr_finder_callback_applies_lr_after_restore(tmp_path):
"""LearningRateFinder used as a callback should apply its suggested LR to the optimizer used after state
restoration.

Without the fix in this PR, the suggested LR was set during the search but lost after the checkpoint restore.

"""

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch.callbacks import LearningRateMonitor

class RandomDataset(Dataset):
def __init__(self, n: int = 256, in_dim: int = 28 * 28):
self.x = torch.randn(n, in_dim)
self.y = torch.randn(n, in_dim)

def __len__(self) -> int:
return len(self.x)

def __getitem__(self, idx):
return self.x[idx], self.y[idx]

class TinyAE(BoringModel):
def __init__(self, lr: float = 1e-5):
super().__init__()
self.save_hyperparameters()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
x, y = batch
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, y)
return loss

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

seed_everything(123)

ds = RandomDataset(n=512)
train_loader = DataLoader(ds, batch_size=64, shuffle=False)

model = TinyAE(lr=1e-5)

lr_finder_cb = LearningRateFinder() # default update_attr=True should apply suggestion
lr_monitor = LearningRateMonitor(logging_interval="step")

trainer = Trainer(
default_root_dir=tmp_path,
max_epochs=2,
callbacks=[lr_finder_cb, lr_monitor],
enable_model_summary=False,
enable_progress_bar=False,
log_every_n_steps=1,
)

trainer.fit(model, train_loader)

# Ensure LR Finder produced a suggestion for this setup; if not, the test can't assert application
assert lr_finder_cb.optimal_lr is not None, "LR Finder should have computed results"
suggestion = lr_finder_cb.optimal_lr.suggestion()
assert suggestion is not None, "LR Finder should produce a suggestion for this setup"

# Verify that the optimizer used for subsequent training has the suggested LR applied
assert trainer.optimizers, "Trainer should have an optimizer after fit"
current_lr = trainer.optimizers[0].param_groups[0]["lr"]
assert current_lr == pytest.approx(suggestion), (
f"LR Finder suggestion {suggestion} should be applied to optimizer, but got {current_lr}"
)


def test_exponential_vs_linear_mode_gradient_difference(tmp_path):
"""Test that exponential and linear modes produce different but valid suggestions.

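The Tuner entry point goes through the same _lr_find code path, so a brief hedged sketch (reusing the illustrative DemoModule and loader from the earlier sketch) of how the fix shows up there: with update_attr=True, the restored optimizer should also carry the suggested LR rather than the pre-search value.

from lightning.pytorch import Trainer
from lightning.pytorch.tuner import Tuner

trainer = Trainer(max_epochs=1, enable_progress_bar=False)
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model, train_dataloaders=loader, update_attr=True)

# The suggestion, the module attribute, and the optimizer param group should now match.
print(lr_finder.suggestion(), model.hparams.lr, trainer.optimizers[0].param_groups[0]["lr"])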