Optimization in a dual encoder LitModel #9807
Hello, I am currently working on a LightningModule that has two encoders. Each encoder has its own optimizer, scheduler, and loss, as shown below:

```python
import importlib

import torch
from pytorch_lightning.core.lightning import LightningModule
from hydra.utils import instantiate

from source.metric.ULMRRMetric import ULMRRMetric


class LitModel(LightningModule):
    def __init__(self, hparams):
        super(LitModel, self).__init__()
        self.save_hyperparameters(hparams)

        # encoders
        self.x1_encoder = instantiate(hparams.x1_encoder)
        self.x2_encoder = instantiate(hparams.x2_encoder)

        # loss functions
        self.x1_loss = instantiate(hparams.x1_loss)
        self.x2_loss = instantiate(hparams.x2_loss)

    def forward(self, x1, x2):
        x1_repr = self.x1_encoder(x1)
        x2_repr = self.x2_encoder(x2)
        return x1_repr, x2_repr

    def training_step(self, batch, batch_idx, optimizer_idx):
        x1, x2 = batch["x1"], batch["x2"]
        x1_repr, x2_repr = self(x1, x2)
        x1_loss = self.x1_loss(x1_repr, x2_repr)
        x2_loss = self.x2_loss(x1_repr, x2_repr)
        # what to return here?
        return

    def validation_step(self, batch, batch_idx):
        x1, x2 = batch["x1"], batch["x2"]
        x1_repr, x2_repr = self(x1, x2)
        self.log("val_x1_LOSS", self.x1_loss(x1_repr, x2_repr), prog_bar=True)
        self.log("val_x2_LOSS", self.x2_loss(x1_repr, x2_repr), prog_bar=True)

    # Alternating schedule for optimizer steps
    def optimizer_step(
        self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
        on_tpu=False, using_native_amp=False, using_lbfgs=False,
    ):
        # update the x1 encoder on even steps
        if optimizer_idx == 0:
            if batch_idx % 2 == 0:
                optimizer.step(closure=optimizer_closure)

        # update the x2 encoder on odd steps
        if optimizer_idx == 1:
            if batch_idx % 2 != 0:
                optimizer.step(closure=optimizer_closure)

    def configure_optimizers(self):
        # optimizers
        optimizers = [
            torch.optim.AdamW(self.x1_encoder.parameters(), lr=self.hparams.lr, betas=(0.9, 0.999), eps=1e-08,
                              weight_decay=self.hparams.weight_decay, amsgrad=True),
            torch.optim.AdamW(self.x2_encoder.parameters(), lr=self.hparams.lr, betas=(0.9, 0.999), eps=1e-08,
                              weight_decay=self.hparams.weight_decay, amsgrad=True)
        ]

        # schedulers
        step_size_up = round(0.03 * self.num_training_steps)
        schedulers = [
            torch.optim.lr_scheduler.CyclicLR(optimizers[0], mode='triangular2', base_lr=self.hparams.base_lr,
                                              max_lr=self.hparams.max_lr, step_size_up=step_size_up,
                                              cycle_momentum=False),
            torch.optim.lr_scheduler.CyclicLR(optimizers[1], mode='triangular2', base_lr=self.hparams.base_lr,
                                              max_lr=self.hparams.max_lr, step_size_up=step_size_up,
                                              cycle_momentum=False)
        ]
        return optimizers, schedulers

    @property
    def num_training_steps(self) -> int:
        """Total training steps inferred from the datamodule and the number of epochs."""
        steps_per_epoch = len(self.train_dataloader()) // self.trainer.accumulate_grad_batches
        max_epochs = self.trainer.max_epochs
        return steps_per_epoch * max_epochs
```

My intention is to update each encoder on alternating steps (even steps: the x1 encoder, odd steps: the x2 encoder). What should training_step return in this setup? I appreciate any help you can provide.
Replies: 1 comment 3 replies
I see two ways. I think your example is simple enough that it does not matter much which one you choose in the end:

1) Automatic Optimization:

```python
def training_step(self, batch, batch_idx, optimizer_idx):
    x1, x2 = batch["x1"], batch["x2"]
    if optimizer_idx == 0:
        x1_repr = self.x1_encoder(x1)
        x2_repr = ...
        x1_loss = self.x1_loss(x1_repr, x2_repr)
        return x1_loss
    if optimizer_idx == 1:
        x2_repr = ...
        return x2_loss

def configure_optimizers(self):
    return [
        {"optimizer": torch.optim.AdamW(self.x1_encoder.parameters(), ...), "frequency": 1},
        {"optimizer": torch.optim.AdamW(self.x2_encoder.parameters(), ...), "frequency": 1},
    ]
```

(and delete your overridden optimizer_step method)

2) Manual Optimization:

```python
def __init__(self, hparams):
    super().__init__()
    self.automatic_optimization = False
    ...

def training_step(self, batch, batch_idx):
    x1, x2 = batch["x1"], batch["x2"]
    opt0, opt1 = self.optimizers()
    if batch_idx % 2 == 0:
        loss = ...
        opt0.zero_grad()
        loss.backward()
        opt0.step()
    else:
        loss = ...
        opt1.zero_grad()
        loss.backward()
        opt1.step()
```

Reference: https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html#manual-optimization

Note: I converted this issue to a GitHub discussion as this is the primary forum for implementation help questions :)
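For completeness, here is a minimal sketch of how option 1 could look once the CyclicLR schedulers from the original `configure_optimizers` are attached. It assumes each loss only needs the two representations, as in the question; the names `opt_x1`, `sched_x1`, etc. are mine, the `"interval": "step"` setting is an assumption (the asker's `step_size_up` is computed in steps, and Lightning steps schedulers per epoch by default), and the `.detach()` calls are an optional design choice:

```python
# Sketch only, assuming the setup from the question: each loss takes the two
# representations, and self.num_training_steps / hparams are defined as above.
def training_step(self, batch, batch_idx, optimizer_idx):
    x1, x2 = batch["x1"], batch["x2"]
    x1_repr, x2_repr = self(x1, x2)
    if optimizer_idx == 0:
        # x1 encoder update: detach the other branch so no gradients
        # accumulate on the encoder that is not being stepped
        return self.x1_loss(x1_repr, x2_repr.detach())
    if optimizer_idx == 1:
        # x2 encoder update
        return self.x2_loss(x1_repr.detach(), x2_repr)

def configure_optimizers(self):
    opt_x1 = torch.optim.AdamW(self.x1_encoder.parameters(), lr=self.hparams.lr,
                               weight_decay=self.hparams.weight_decay, amsgrad=True)
    opt_x2 = torch.optim.AdamW(self.x2_encoder.parameters(), lr=self.hparams.lr,
                               weight_decay=self.hparams.weight_decay, amsgrad=True)

    step_size_up = round(0.03 * self.num_training_steps)
    sched_x1 = torch.optim.lr_scheduler.CyclicLR(
        opt_x1, mode="triangular2", base_lr=self.hparams.base_lr,
        max_lr=self.hparams.max_lr, step_size_up=step_size_up, cycle_momentum=False)
    sched_x2 = torch.optim.lr_scheduler.CyclicLR(
        opt_x2, mode="triangular2", base_lr=self.hparams.base_lr,
        max_lr=self.hparams.max_lr, step_size_up=step_size_up, cycle_momentum=False)

    return [
        {"optimizer": opt_x1,
         "lr_scheduler": {"scheduler": sched_x1, "interval": "step"},
         "frequency": 1},
        {"optimizer": opt_x2,
         "lr_scheduler": {"scheduler": sched_x2, "interval": "step"},
         "frequency": 1},
    ]
```

With `"frequency": 1` on both entries, Lightning alternates the two optimizers batch by batch, so the even/odd `optimizer_step` override from the question is no longer needed; drop the `.detach()` calls if you do want gradients to flow through both encoders on every loss.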