How to switch data loaders between epochs while using multiple optimizers #14420
Hi, I have two datasets, and each dataset has a corresponding optimizer. Right now I am updating the model parameters per batch per dataset. But how can I iterate over only one dataset per epoch, switching between the two datasets from epoch to epoch?
This is what I have for iterating multiple datasets per batch:

```python
from torch import optim
from pytorch_lightning import LightningModule, LightningDataModule


class Mymodel(LightningModule):
    def configure_optimizers(self):
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=self.lr,
        )
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        scheduler_config = {
            "scheduler": scheduler,
            "monitor": "val/rec/epoch_loss",
            "interval": "epoch",
            "frequency": 1,
        }
        pre_optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=self.lr_for_other_modules,
        )
        # return both optimizers so that training_step receives optimizer_idx
        return [
            {"optimizer": optimizer, "lr_scheduler": scheduler_config},
            {"optimizer": pre_optimizer},
        ]

    def training_step(self, batch, batch_idx, optimizer_idx):
        if optimizer_idx == 0:
            # do forward on datasetA
            inputs = batch["datasetA"]
        if self.is_use_some_modules and optimizer_idx == 1:
            # do something else
            inputs = batch["datasetB"]
        return ...
```
```python
class Datamodule(LightningDataModule):
    def train_dataloader(self):
        ...
        return {"datasetA": A_loader, "datasetB": B_loader}
```

I found a similar question here (#3336), and the suggested solution is reloading the data loaders every epoch:

```python
class Datamodule(LightningDataModule):
    def train_dataloader(self):
        if self.current_epoch % 2 == 0:
            A_loader = ...
            return A_loader
        else:
            B_loader = ...
            return B_loader


class Mymodel(LightningModule):
    def training_step(self, batch, batch_idx, optimizer_idx):
        if self.current_epoch % 2 == 0 and optimizer_idx == 0:
            inputs = batch["datasetA"]
        elif self.current_epoch % 2 == 1 and optimizer_idx == 1:
            inputs = batch["datasetB"]
        return ...
```
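One practical detail with the reload approach: `LightningDataModule` itself has no `current_epoch` attribute, so the sketch above cannot read `self.current_epoch` directly. Below is a minimal, hedged sketch of one way around this, assuming a Lightning version that attaches a `trainer` reference to the datamodule; `dataset_a`, `dataset_b`, and `batch_size` are illustrative attribute names, not part of the original post.

```python
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader


class Datamodule(LightningDataModule):
    def train_dataloader(self):
        # A LightningDataModule has no current_epoch of its own; read the
        # epoch from the trainer it is attached to (0 before training starts).
        epoch = self.trainer.current_epoch if self.trainer is not None else 0
        if epoch % 2 == 0:
            # dataset_a / dataset_b / batch_size are illustrative attributes
            return DataLoader(self.dataset_a, batch_size=self.batch_size, shuffle=True)
        return DataLoader(self.dataset_b, batch_size=self.batch_size, shuffle=True)
```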
Replies: 1 comment 3 replies
One simple solution is to use `reload_dataloaders_every_epoch=1` and configure `training_step` based on `current_epoch`:

```python
def training_step(self, batch, batch_idx, optimizer_idx):
    if (self.current_epoch % 2 == 0 and optimizer_idx == 0) or (
        self.current_epoch % 2 == 1 and optimizer_idx == 1
    ):
        ...
        return loss
```

In the other cases, `training_step` will return `None`, which will not make any updates to the other optimizer. You might get a warning, but that's fine. Your solution is fine as well, but it will load the data from both dataloaders, which is not required in each epoch.
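To make the reply concrete, here is a minimal end-to-end sketch of the alternating setup. It is not the thread author's code: it assumes a Lightning 1.x-style multi-optimizer `training_step(..., optimizer_idx)` under automatic optimization, batches of `(x, y)` pairs with 16 features, and the datamodule sketched after the question above; the class name `AlternatingModel` and the losses are illustrative. In recent Lightning versions the reload option is spelled `reload_dataloaders_every_n_epochs`.

```python
import torch
import pytorch_lightning as pl
from torch import optim


class AlternatingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(16, 1)  # placeholder network

    def configure_optimizers(self):
        # one optimizer per dataset, as in the question
        opt_a = optim.Adam(self.parameters(), lr=1e-3)
        opt_b = optim.SGD(self.parameters(), lr=1e-2)
        return [opt_a, opt_b]

    def training_step(self, batch, batch_idx, optimizer_idx):
        # Even epochs: dataset A drives optimizer 0; odd epochs: dataset B
        # drives optimizer 1. Returning None skips the other optimizer.
        x, y = batch
        if self.current_epoch % 2 == 0 and optimizer_idx == 0:
            return torch.nn.functional.mse_loss(self.net(x), y)
        if self.current_epoch % 2 == 1 and optimizer_idx == 1:
            return torch.nn.functional.l1_loss(self.net(x), y)
        return None


# Reload (and thus switch) the train dataloader at the start of every epoch;
# the datamodule picks the dataset as in the sketch after the question above.
trainer = pl.Trainer(max_epochs=10, reload_dataloaders_every_n_epochs=1)
# trainer.fit(AlternatingModel(), datamodule=Datamodule(...))
```

Because only one of the two conditions can hold in a given epoch, every batch of that epoch updates exactly one optimizer, and the other optimizer simply receives `None` (with the warning mentioned above).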