Link arguments from Datamodule into init_args of lr_scheduler #11628
-
Hey! I'm trying to use

```python
class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # Set lr_scheduler's num_training_steps from the datamodule
        parser.link_arguments(
            "data",
            "lr_scheduler.init_args.num_training_steps",
            compute_fn=lambda dm: dm.get_num_training_steps(),
            apply_on="instantiate",
        )
```

and I get the following error:

```
ValueError: No action for key "lr_scheduler.init_args.num_training_steps".
```

I was wondering whether such a thing is possible, or is linking to the lr_scheduler's `init_args` not supported?

Code to reproduce:
```python
import pytorch_lightning as pl
import torch.nn
from pytorch_lightning.utilities.cli import LR_SCHEDULER_REGISTRY, LightningCLI
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset


class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # Set lr_scheduler's num_training_steps from the datamodule
        parser.link_arguments(
            "data",
            "lr_scheduler.init_args.num_training_steps",
            compute_fn=lambda dm: dm.get_num_training_steps(),
            apply_on="instantiate",
        )


@LR_SCHEDULER_REGISTRY
class WarmupLR(LambdaLR):
    """Linear warmup followed by linear decay."""

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_proportion: float,
        num_training_steps: int,
        last_epoch=-1,
    ) -> None:
        self.num_training_steps = num_training_steps
        self.num_warmup_steps = round(num_training_steps * warmup_proportion)
        super().__init__(optimizer, lr_lambda=self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, current_step: int) -> float:
        if current_step < self.num_warmup_steps:
            return float(current_step) / float(max(1, self.num_warmup_steps))
        return max(
            0.0,
            float(self.num_training_steps - current_step)
            / float(max(1, self.num_training_steps - self.num_warmup_steps)),
        )


class DataModule(pl.LightningDataModule):
    def __init__(self, name):
        super().__init__()
        self.length = len(name)

    def train_dataloader(self):
        # Placeholder loader; only the CLI parsing matters for this reproduction
        return DataLoader(Dataset())

    def get_num_training_steps(self) -> int:
        return self.length


class LitModel(pl.LightningModule):
    def __init__(self, num_labels):
        super().__init__()
        self.num_labels = num_labels
        self.nn = torch.nn.Linear(num_labels, num_labels)

    def training_step(self, *args, **kwargs):
        # No-op training step for the reproduction
        return


if __name__ == "__main__":
    cli = MyLightningCLI(
        model_class=LitModel,
        datamodule_class=DataModule,
    )
```
with the following config:

```yaml
data:
  name: blablabla
model:
  num_labels: 5
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.01
lr_scheduler:
  warmup_proportion: 0.1
trainer:
  max_epochs: 2
```
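To make the intent concrete, here is roughly what I'd like the link to automate (illustrative only; it reuses the classes from the reproducer above):

```python
from torch.optim import Adam

# Manual equivalent of the desired link: the scheduler's num_training_steps
# should come from the instantiated datamodule.
dm = DataModule(name="blablabla")
model = LitModel(num_labels=5)
optimizer = Adam(model.parameters(), lr=0.01)
scheduler = WarmupLR(
    optimizer,
    warmup_proportion=0.1,
    num_training_steps=dm.get_num_training_steps(),  # value taken from the datamodule
)
```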
Replies: 2 comments
-
cc: @carmocca
-
You need to add an empty `configure_optimizers` method to your model, as there's a bug that disallows leaving it unimplemented. It will be fixed with #11672 (see the sketch after the code below).

The error appears because the `lr_scheduler` arguments have not been added yet at the point where `add_arguments_to_parser` runs. You can see the order here:

https://github.com/PyTorchLightning/pytorch-lightning/blob/86b177ebe5427725b35fde1a8808a7b59b8a277a/pytorch_lightning/utilities/cli.py#L603-L609

So you have two options:
**Option 1**: move the link into `link_optimizers_and_lr_schedulers`:

```python
class MyLightningCLI(LightningCLI):
    @staticmethod
    def link_optimizers_and_lr_schedulers(parser):
        # Set lr_scheduler's num_training_steps from the datamodule
        parser.link_arguments(
            "data",
            "lr_scheduler.init_args.num_training_steps",
            compute_fn=lambda dm: dm.get_num_training_steps(),
            apply_on="instantiate",
        )
        LightningCLI.link_optimizers_and_lr_schedulers(parser)
```
**Option 2**: add the scheduler classes yourself in `add_arguments_to_parser` before linking:

```python
class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # Manually add the lr scheduler classes
        parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes)
        # Set lr_scheduler's num_training_steps from the datamodule
        parser.link_arguments(
            "data",
            "lr_scheduler.init_args.num_training_steps",
            compute_fn=lambda dm: dm.get_num_training_steps(),
            apply_on="instantiate",
        )
```

Also, the config for the scheduler should be:

```yaml
lr_scheduler:
  class_path: __main__.WarmupLR
  init_args:
    warmup_proportion: 0.01
    num_training_steps: 1
```
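For completeness, the empty `configure_optimizers` workaround mentioned at the top would look something like this (a minimal sketch against the reproducer's `LitModel`; the stub only exists to work around the bug tracked in #11672 and is expected to be superseded by the optimizer/scheduler configured through the CLI):

```python
class LitModel(pl.LightningModule):
    def __init__(self, num_labels):
        super().__init__()
        self.num_labels = num_labels
        self.nn = torch.nn.Linear(num_labels, num_labels)

    def training_step(self, *args, **kwargs):
        return

    def configure_optimizers(self):
        # Intentionally empty: the optimizer and lr_scheduler given in the CLI
        # config take over. Only needed until #11672 is fixed.
        pass
```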