
MisconfigurationException with ReduceLROnPlateau when taking a single step #8839

@nathancooperjones

Description

πŸ› Bug

When using a learning rate scheduler that requires a monitor, such as ReduceLROnPlateau, and setting up a Trainer to train for only a single step with max_steps=1, we get a MisconfigurationException saying that the only monitors the LR scheduler can use are those associated with the training metrics, not those from the validation or testing epochs. This behavior is relatively new and seems to have been added in the 1.4.0 release around these lines:

https://github.com/PyTorchLightning/pytorch-lightning/blob/a64cc373946a755ce6c3aef57c1be607dfe29a0c/pytorch_lightning/trainer/connectors/optimizer_connector.py#L61-L74
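For context, the failure appears to happen because with max_steps=1 the epoch-level scheduler update fires before the validation loop has run, so callback_metrics only contains the training metrics at that point. A rough paraphrase of the check at those lines (a sketch of the logic, not the exact source; `trainer` and `monitor_key` stand in for the surrounding context):

from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Paraphrased sketch of the 1.4.x check in optimizer_connector.py.
# `monitor_key` is the `monitor` entry of the LR scheduler dict returned
# by `configure_optimizers`.
monitor_val = trainer.callback_metrics.get(monitor_key)
if monitor_val is None:
    avail_metrics = list(trainer.callback_metrics.keys())
    raise MisconfigurationException(
        f"ReduceLROnPlateau conditioned on metric {monitor_key}"
        f" which is not available. Available metrics are: {avail_metrics}."
        " Condition can be set using `monitor` key in lr scheduler dict"
    )
# With max_steps=1, validation has not run yet when this executes, so
# `callback_metrics` holds only ['train_loss_epoch'] and the check raises.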

To Reproduce

Interactive notebook to see this bug: https://colab.research.google.com/drive/1gV7gRr9KzVE_oZrGJcX9o2yYngtobNJJ?usp=sharing

A quick example without the notebook is here:

import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning import LightningModule


class RandomDataset(Dataset):
    # Minimal random dataset so the example is self-contained
    def __init__(self, size: int, length: int):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

class BoringModel(LightningModule):

    def __init__(self, monitor: str):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

        self.monitor = monitor

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log(name='train_loss_epoch', value=avg_loss)

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def validation_epoch_end(self, outputs) -> None:
        avg_loss = torch.stack([x['x'] for x in outputs]).mean()
        self.log(name='val_loss_epoch', value=avg_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler_dict = {
            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
            'monitor': self.monitor,
        }
        return [optimizer], [lr_scheduler_dict]


def test_x(tmpdir, monitor: str):
    # init model with the metric the LR scheduler should monitor
    model = BoringModel(monitor=monitor)

    # small random train/val loaders
    train = DataLoader(RandomDataset(32, 64), batch_size=2)
    val = DataLoader(RandomDataset(32, 64), batch_size=2)

    # Initialize a trainer
    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_steps=1,  # note that only a single step is taken here
    )

    # Train the model
    trainer.fit(model, train, val)

This works without error:

>>> test_x(tmpdir, monitor='train_loss_epoch')

But this call fails with the error:

>>> test_x(tmpdir, monitor='val_loss_epoch')
MisconfigurationException: ReduceLROnPlateau conditioned on metric val_loss_epoch which is not available. Available metrics are: ['train_loss_epoch']. Condition can be set using `monitor` key in lr scheduler dict

Expected behavior

Expected behavior would be for the learning rate scheduler to see all metrics available from the validation and testing epochs from the start, and to be able to use them even when only a single step is taken. In the current setup, debugging and testing a pipeline by taking a single step does not work, requiring instead that every instance of max_steps=1 be replaced with max_epochs=1 (a workaround along those lines is sketched below).

Ideally, both:

>>> test_x(tmpdir, monitor='train_loss_epoch')
>>> test_x(tmpdir, monitor='val_loss_epoch')

would work without error.
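In the meantime, the single-batch debugging run can be approximated with max_epochs=1 plus the limit_train_batches / limit_val_batches Trainer flags, so the validation loop still runs and logs val_loss_epoch before the epoch-level scheduler update. A minimal sketch reusing the classes above (test_x_workaround is a hypothetical name, not part of the report):

def test_x_workaround(tmpdir, monitor: str):
    model = BoringModel(monitor=monitor)

    train = DataLoader(RandomDataset(32, 64), batch_size=2)
    val = DataLoader(RandomDataset(32, 64), batch_size=2)

    # One epoch limited to a single training batch instead of max_steps=1:
    # validation runs and populates `val_loss_epoch` before the scheduler
    # update, so monitoring a validation metric no longer raises.
    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
    )
    trainer.fit(model, train, val)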

Environment

Please copy and paste the output from our environment collection script:
https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/requirements/collect_env_details.py
(For security purposes, please check the contents of the script before running it)

You can get the script and run it with:

wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/requirements/collect_env_details.py
python collect_env_details.py

^ this script no longer exists btw

  • PyTorch Lightning Version (e.g., 1.3.0): 1.4.1
  • PyTorch Version (e.g., 1.8): 1.9.0
  • Python version: 3.8
  • OS (e.g., Linux): Linux
  • CUDA/cuDNN version: N/A
  • GPU models and configuration: N/A
  • How you installed PyTorch (conda, pip, source): pip
  • If compiling from source, the output of torch.__config__.show(): N/A
  • Any other relevant information: N/A

Additional context

Great library, thank you for development on this!! 😄
