Well, there are multiple ways:

  1. If your max_epochs is consistent across all the tasks:

from pytorch_lightning import LightningModule

class LitModel(LightningModule):
    def on_train_epoch_start(self):
        # fires at the first epoch and once per task block thereafter
        if self.current_epoch == 0 or (self.current_epoch + 1) % self.trainer.reload_dataloaders_every_n_epochs == 0:
            # update model parameters for the next task here
            ...


max_epochs_n_tasks = max_epochs * n_tasks
trainer = Trainer(max_epochs=max_epochs_n_tasks, reload_dataloaders_every_n_epochs=max_epochs)
model = LitModel()

# inject the task-counter update logic inside the datamodule
dm = RandSplitCIFAR100DataModule(...)
trainer.fit(model, datamodule=dm)
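To make the schedule concrete, here is a small standalone sketch (plain Python, no Lightning required) of how the single long run maps epochs to tasks; the values `epochs_per_task = 5` and `n_tasks = 3` are hypothetical, chosen only for illustration:

```python
# Hypothetical values for illustration: 5 epochs per task, 3 tasks.
epochs_per_task = 5
n_tasks = 3
total_epochs = epochs_per_task * n_tasks  # what max_epochs is set to

def task_for_epoch(current_epoch):
    # which task's data the dataloader serves at this epoch
    return current_epoch // epochs_per_task

def triggers_update(current_epoch):
    # mirrors the on_train_epoch_start check in the snippet above
    return current_epoch == 0 or (current_epoch + 1) % epochs_per_task == 0

schedule = [task_for_epoch(e) for e in range(total_epochs)]
update_epochs = [e for e in range(total_epochs) if triggers_update(e)]
```

Here `schedule` shows five epochs on task 0, then five on task 1, then five on task 2, and `update_epochs` shows at which epochs the hook's condition fires.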
  2. Create an explicit loop and re-instantiate the Trainer for each task:
def init_trainer(...):
    trainer = Trainer(max_epochs=max_epochs, ...)
    return trainer
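The loop itself isn't shown above; here is a minimal control-flow sketch. The names `make_trainer` and `fit_task` are assumptions: `make_trainer` stands in for a configured factory like `init_trainer`, and `fit_task` stands in for a call like `trainer.fit(model, datamodule=dm)` with that task's datamodule:

```python
def run_sequential_tasks(make_trainer, fit_task, n_tasks):
    # make_trainer(): returns a fresh Trainer so each task gets its own
    # fit loop and epoch counter (hypothetical factory, e.g. init_trainer)
    # fit_task(trainer, task_id): wraps the trainer.fit(...) call
    trainers = []
    for task_id in range(n_tasks):
        trainer = make_trainer()
        fit_task(trainer, task_id)
        trainers.append(trainer)
    return trainers
```

Re-creating the Trainer per task keeps each task's epoch count and checkpointing independent, at the cost of managing state (e.g. the model weights) across fits yourself.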

Answer selected by prateeky2806