LearningRateFinder: Learning Rate is not set correctly with the callback #21030

@cvoscode


Bug description

When using LearningRateFinder() as a pl.Trainer callback, it reports that it has set a learning rate, but it seemingly does not. When logging the learning rate with self.log("learning_rate", self.optimizers().param_groups[0]['lr']) or with LearningRateMonitor(), the learning rate is still the one defined in the pl.LightningModule, not the one "set" by the LearningRateFinder().
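
A quick way to see the mismatch, using the module and data from the reproduction below, is to compare the attribute the finder writes back with the learning rate inside the live optimizer. A minimal diagnostic sketch; it assumes trainer.optimizers exposes the built optimizers, as in recent 2.x releases:

from pytorch_lightning.callbacks import LearningRateFinder

model = LitAutoEncoder()
print("lr before fit:", model.lr)  # 1e-05 from __init__

trainer = pl.Trainer(callbacks=[LearningRateFinder()], max_epochs=1)
trainer.fit(model, DataLoader(train), DataLoader(val))

print("lr attribute after fit:", model.lr)  # value the finder is supposed to write back
print("optimizer lr:", trainer.optimizers[0].param_groups[0]["lr"])  # value actually used for updates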

PyTorch Lightning version: 2.5.2
torch version: 2.7 (CUDA 12.8)
Running in a Jupyter notebook.

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl

print(pl.__version__)
print(torch.__version__)

class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
        self.lr = 0.00001

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        self.log("learning_rate", self.optimizers().param_groups[0]['lr'])
        return loss

    def configure_optimizers(self):
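        # self.lr is read only once, here, when the Trainer builds the optimizer;
        # a value assigned to self.lr later will not reach an already-built
        # optimizer unless its param groups are updated as well.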
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

from pytorch_lightning.callbacks import LearningRateFinder, LearningRateMonitor
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])
autoencoder = LitAutoEncoder()
trainer = pl.Trainer(callbacks=[LearningRateFinder(), LearningRateMonitor()], max_epochs=50)
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
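
For comparison, running the finder manually through the Tuner before calling fit() does propagate the suggestion, because configure_optimizers() re-reads self.lr when fit() builds the optimizer. A sketch of the documented Tuner workflow; lr_find should already update model.lr itself (update_attr defaults to True), so the explicit assignment below just makes that step visible, and suggestion() can return None when no clear suggestion is found:

from pytorch_lightning.tuner import Tuner

autoencoder = LitAutoEncoder()
trainer = pl.Trainer(max_epochs=50)

tuner = Tuner(trainer)
lr_finder = tuner.lr_find(autoencoder, DataLoader(train))  # runs the LR range test
if lr_finder is not None and lr_finder.suggestion() is not None:
    autoencoder.lr = lr_finder.suggestion()  # configure_optimizers() picks this up during fit()

trainer.fit(autoencoder, DataLoader(train), DataLoader(val))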

Error messages and logs

Logger output (metrics.csv). Both learning-rate columns stay at the initial 1e-05 for the entire run (the learning_rate column shows its float32 representation):

epoch,learning_rate,lr-Adam,step,train_loss
,,1e-05,0,
,,1e-05,49,
0,9.999999747378752e-06,,49,0.1070743054151535
,,1e-05,99,
0,9.999999747378752e-06,,99,0.20870453119277954
,,1e-05,149,
0,9.999999747378752e-06,,149,0.11257157474756241
,,1e-05,199,
0,9.999999747378752e-06,,199,0.0726708397269249
,,1e-05,249,
0,9.999999747378752e-06,,249,0.1520737260580063
,,1e-05,299,
0,9.999999747378752e-06,,299,0.14219748973846436
,,1e-05,349,
0,9.999999747378752e-06,,349,0.08788570761680603
,,1e-05,399,
0,9.999999747378752e-06,,399,0.15874463319778442
,,1e-05,449,
0,9.999999747378752e-06,,449,0.14308691024780273
,,1e-05,499,
0,9.999999747378752e-06,,499,0.09943148493766785
,,1e-05,549,
0,9.999999747378752e-06,,549,0.0940447598695755
,,1e-05,599,
0,9.999999747378752e-06,,599,0.06538216769695282
,,1e-05,649,
0,9.999999747378752e-06,,649,0.09586357325315475


Environment

Current environment
- PyTorch Lightning Version: 2.5.2
- PyTorch Version: 2.7
- CUDA/cuDNN version: 12.8
- How you installed Lightning (`conda`, `pip`, source): pip/uv

More info

No response
