
Wrong LR scheduler behaviour when using AMP #16228

@milesial


Bug description

When training with native AMP and an LR scheduler, a warning is emitted saying that `lr_scheduler.step()` was called before `optimizer.step()`, because the optimizer step was skipped (skipped optimizer steps are expected at the beginning of training with native AMP, while the gradient scaler calibrates its scale). See the warning under "Error messages and logs" below.
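For context, here is a minimal plain-PyTorch sketch (not the Lightning code path; it assumes a CUDA device is available) of why the first optimizer steps are skipped under native AMP: the GradScaler detects inf/NaN gradients, skips `optimizer.step()`, and lowers its scale, while the scheduler is stepped unconditionally:

import torch

model = torch.nn.Linear(32, 2).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(opt, 10)
scaler = torch.cuda.amp.GradScaler()

for step in range(5):
    opt.zero_grad()
    with torch.autocast("cuda", dtype=torch.float16):
        # Huge loss so the fp16 gradients overflow to inf.
        loss = model(torch.randn(2, 32, device="cuda")).sum() * 100000000
    scaler.scale(loss).backward()
    scaler.step(opt)   # silently skipped while the gradients are inf/NaN
    scaler.update()    # lowers the scale after a skipped step
    sched.step()       # stepping here anyway triggers the PyTorch warning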

This can be fixed by wrapping these lines https://github.com/Lightning-AI/lightning/blob/574a9516012b4ab778254055c537f5d57e8e694f/src/pytorch_lightning/core/module.py#L1589-L1592

in `if hasattr(optimizer, '_step_count') and optimizer._step_count > 0:`, so the scheduler only steps once the optimizer has actually taken a step.
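As a rough sketch of the proposed guard (assuming the linked lines are the default `LightningModule.lr_scheduler_step` hook that calls `scheduler.step()`; the exact signature is an assumption here):

def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
    optimizer = scheduler.optimizer
    # Skip the scheduler step while the GradScaler is still skipping optimizer steps.
    if hasattr(optimizer, "_step_count") and optimizer._step_count > 0:
        if metric is None:
            scheduler.step()
        else:
            scheduler.step(metric)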

Fix proposed in #16229

How to reproduce the bug

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # Scale the loss to a huge value so the fp16 gradients overflow and the
        # GradScaler skips the first optimizer steps.
        loss = self(batch).sum() * 100000000
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        opt = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        sched = torch.optim.lr_scheduler.StepLR(opt, 10)
        # Step the scheduler every batch so the warning appears immediately.
        return {"optimizer": opt, "lr_scheduler": {"scheduler": sched, "interval": "step"}}

def run():
    train_data = DataLoader(RandomDataset(32, 32), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        accelerator='gpu',
        devices=1,
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=1,
        precision=16,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()

Error messages and logs

/usr/local/lib/python3.8/dist-packages/torch/optim/lr_scheduler.py:138: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

Environment

No response

More info

No response
