diff --git a/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst b/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst
index db06f85c359bd..f68316128f632 100644
--- a/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst
+++ b/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst
@@ -201,9 +201,10 @@ If the scheduler you want needs other arguments, add them via the CLI (no need t
 
 .. code:: bash
 
-    python main.py fit --optimizer=Adam --lr_scheduler=ReduceLROnPlateau --lr_scheduler.monitor=epoch
+    python main.py fit --optimizer=Adam --lr_scheduler=ReduceLROnPlateau --lr_scheduler.monitor=train_loss
 
-Furthermore, any custom subclass of ``torch.optim.lr_scheduler.LRScheduler`` can be used as learning rate scheduler:
+(assuming you have a ``train_loss`` metric logged). Furthermore, any custom subclass of
+``torch.optim.lr_scheduler.LRScheduler`` can be used as learning rate scheduler:
 
 .. code:: python
 
@@ -212,7 +213,6 @@ Furthermore, any custom subclass of ``torch.optim.lr_scheduler.LRScheduler`` can
     from lightning.pytorch.cli import LightningCLI
     from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
 
-
     class LitLRScheduler(torch.optim.lr_scheduler.CosineAnnealingLR):
         def step(self):
             print("⚡", "using LitLRScheduler", "⚡")
diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py
index 91247127f6c87..fe4f41b4f4e4d 100644
--- a/src/lightning/pytorch/cli.py
+++ b/src/lightning/pytorch/cli.py
@@ -66,6 +66,13 @@
 
 
 class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
+    """Custom ReduceLROnPlateau scheduler that extends PyTorch's ReduceLROnPlateau.
+
+    This class adds a `monitor` attribute to the standard PyTorch ReduceLROnPlateau to specify which metric should be
+    tracked for learning rate adjustment.
+
+    """
+
     def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
         super().__init__(optimizer, *args, **kwargs)
         self.monitor = monitor
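
As context for the doc change above, here is a minimal sketch (not part of the diff) of a ``main.py`` that logs a ``train_loss`` metric, so the documented ``--lr_scheduler.monitor=train_loss`` flag has a value to track. The ``TinyModel`` class, its layer sizes, and the logged metric name are illustrative assumptions; only ``LightningCLI``, ``LightningModule``, ``BoringDataModule``, and ``self.log`` come from Lightning itself.

.. code:: python

    # Sketch only: the model and metric names below are assumptions for illustration.
    import torch

    from lightning.pytorch import LightningModule
    from lightning.pytorch.cli import LightningCLI
    from lightning.pytorch.demos.boring_classes import BoringDataModule


    class TinyModel(LightningModule):
        def __init__(self, out_dim: int = 2):
            super().__init__()
            self.layer = torch.nn.Linear(32, out_dim)

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).sum()
            # Log under the same name passed to --lr_scheduler.monitor so the
            # ReduceLROnPlateau wrapper has a metric to track.
            self.log("train_loss", loss)
            return loss


    if __name__ == "__main__":
        # Optimizer and scheduler are supplied on the command line, e.g.:
        #   python main.py fit --optimizer=Adam \
        #       --lr_scheduler=ReduceLROnPlateau --lr_scheduler.monitor=train_loss
        LightningCLI(TinyModel, BoringDataModule)

With a module like this, the command shown in the updated docs lets the CLI-provided ``ReduceLROnPlateau`` wrapper (the class whose docstring is added in ``src/lightning/pytorch/cli.py``) read the logged value through its ``monitor`` attribute.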