Dear @turian,

Yes, it is possible.

You could do something like this.

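```python
from typing import List

import torch
import torch.nn as nn
from pytorch_lightning import LightningModule, Trainer


class MultiModels(LightningModule):
    def __init__(self, models: List[nn.Module]):
        super().__init__()
        # nn.ModuleList (unlike a plain list) registers every sub-model's parameters
        self.models = nn.ModuleList(models)

    def compute_loss(self, model, batch):
        loss = ...  # your per-model loss goes here
        return loss

    def training_step(self, batch, batch_idx):
        # summing the losses lets a single backward pass update every model
        loss = sum(self.compute_loss(model, batch) for model in self.models)
        return loss

    def configure_optimizers(self):
        # one optimizer over all sub-models (Adam and lr=1e-3 are just examples)
        return torch.optim.Adam(self.parameters(), lr=1e-3)


model = MultiModels([resnet50_model, alexnet_model, ...])
dm = ...
trainer = Trainer()
trainer.fit(model, dm)
```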
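Two details make this pattern work, and both can be checked in plain PyTorch without Lightning. First, sub-models stored in an `nn.ModuleList` are registered with the parent module, so the optimizer sees their parameters; a plain Python list is not registered. Second, because the models share no parameters, backpropagating the summed loss gives each model exactly the gradient it would get from its own loss alone. The sketch below assumes a toy regression task with MSE loss; `MultiModels`, `PlainListModels` and `training_loss` here are hypothetical names used only for illustration.

```python
import torch
import torch.nn as nn


class MultiModels(nn.Module):
    """Hypothetical plain-PyTorch stand-in for the LightningModule above."""

    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleList(models)  # registers every sub-model

    def compute_loss(self, model, batch):
        # Assumed task: regression with mean squared error on (x, y) pairs.
        x, y = batch
        return nn.functional.mse_loss(model(x), y)

    def training_loss(self, batch):
        return sum(self.compute_loss(m, batch) for m in self.models)


class PlainListModels(nn.Module):
    def __init__(self, models):
        super().__init__()
        self.models = models  # plain Python list: sub-models are NOT registered


torch.manual_seed(0)
small = nn.Linear(4, 1)  # 4*1 + 1 = 5 parameters
mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))  # 40 + 9 = 49 parameters
multi = MultiModels([small, mlp])
batch = (torch.randn(16, 4), torch.randn(16, 1))

n_params = sum(p.numel() for p in multi.parameters())
n_plain = sum(p.numel() for p in PlainListModels([small, mlp]).parameters())

# One backward pass through the summed loss.
multi.zero_grad()
multi.training_loss(batch).backward()
grad_from_sum = small.weight.grad.clone()

# The gradient for `small` computed from its own loss alone.
multi.zero_grad()
multi.compute_loss(small, batch).backward()
grad_alone = small.weight.grad.clone()

same_grad = torch.allclose(grad_from_sum, grad_alone)
print(n_params, n_plain, same_grad)  # 54 0 True
```

Because the gradients decouple like this, one optimizer over `self.parameters()` trains each model as if it had its own optimizer with the same settings, as long as nothing couples them (gradient clipping by global norm, for example, would).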

Does this answer your question?

Answer selected by turian