How to implement a deep ensemble #8505
-
I am looking to implement the following deep ensemble:

class DeepEnsemble(LightningModule):
def __init__(self, cfg):
    """Build an ensemble of ``cfg.METHOD.ENSEMBLE`` networks.

    Args:
        cfg: Configuration object. ``cfg.METHOD.ENSEMBLE`` sets the number of
            members; the whole object is passed to ``configure_network`` for
            each member and kept as ``self.cfg`` for the other methods.
    """
    # LightningModule forwards constructor arguments to torch.nn.Module, which
    # does not accept a config object, so initialise the base class without it
    # and store cfg explicitly (every other method reads self.cfg).
    super().__init__()
    self.cfg = cfg
    self.net = nn.ModuleList(
        [configure_network(self.cfg) for _ in range(self.cfg.METHOD.ENSEMBLE)]
    )
def configure_optimizers(self):
    """Return one Adam optimizer per ensemble member, in ``self.net`` order.

    Each optimizer owns only its member's parameters, so an optimizer's
    position in the returned list identifies the member it trains.
    """
    optimizers = []
    for member in self.net:
        optimizers.append(
            torch.optim.Adam(member.parameters(), lr=self.cfg.SOLVER.LR)
        )
    return optimizers
def forward(self, x):
    """Run every ensemble member on ``x``.

    Args:
        x: Input batch, passed unchanged to each member.

    Returns:
        A list with one output per member, in ``self.net`` order.
    """
    # Call the modules instead of ``.forward`` so registered hooks still run.
    # NOTE(review): members still run one after another; running them in
    # parallel would need e.g. torch.func.stack_module_state + vmap.
    return [net(x) for net in self.net]
def training_step(self, batch, batch_idx, optimizer_idx):
    """Return the loss of the ensemble member trained by ``optimizer_idx``.

    With several optimizers, Lightning's automatic optimization calls this
    once per optimizer for every batch and steps only that optimizer with the
    returned loss, so the loss must belong to member ``optimizer_idx`` alone.

    Args:
        batch: Mapping with "image" and "label" entries.
        batch_idx: Index of the batch (unused).
        optimizer_idx: Index of the optimizer, and hence of the member, being
            trained in this call.

    Returns:
        The criterion value for member ``optimizer_idx``.
    """
    image, label = batch["image"], batch["label"]
    logits = self.forward(image)
    loss = self.criterion(logits[optimizer_idx], label)
    # Log the ensemble-mean metrics once per batch (on the last member's call)
    # instead of once per optimizer call.
    if optimizer_idx == len(self.net) - 1:
        mean_logit = torch.stack(logits, dim=-1).mean(dim=-1)
        self.log_metrics(mean_logit, label, 'train')
    return loss
def validation_step(self, batch, batch_idx):
    """Evaluate the averaged ensemble prediction on one validation batch.

    Member outputs are stacked on a new trailing dimension and averaged,
    ``log_metrics`` is called with the 'val' tag, and the entry of its result
    selected by ``cfg.CKPT.MONITOR`` is returned.
    """
    inputs = batch["image"]
    targets = batch["label"]
    member_outputs = self.forward(inputs)
    ensemble_logit = torch.stack(member_outputs, dim=-1).mean(dim=-1)
    logged = self.log_metrics(ensemble_logit, targets, 'val')
    return logged[self.cfg.CKPT.MONITOR]
def test_step(self, batch, batch_idx):
pass

I have […]. In addition, it would be nice to run all forward passes in parallel instead of sequentially, as this list comprehension does. So what is the most elegant way to train an ensemble while still having access to all predictions together for metric logging?
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 3 replies
-
I see two potential options.
|
Beta Was this translation helpful? Give feedback.
-
Here is a solution that caches the predictions:

class DeepEnsemble(pl.LightningModule):
def __init__(self, cfg):
    """Build the ensemble and an empty prediction cache.

    Args:
        cfg: Configuration object. ``cfg.METHOD.ENSEMBLE`` sets the number of
            members; the whole object is passed to ``configure_network`` for
            each member and kept as ``self.cfg`` for the other methods.
    """
    # LightningModule forwards constructor arguments to torch.nn.Module, which
    # does not accept a config object, so initialise the base class without it
    # and store cfg explicitly (every other method reads self.cfg).
    super().__init__()
    self.cfg = cfg
    self.net = nn.ModuleList(
        [configure_network(self.cfg) for _ in range(self.cfg.METHOD.ENSEMBLE)]
    )
    # Detached per-member logits for the batch currently being trained;
    # filled and cleared by training_step.
    self.cache_preds = []
def configure_optimizers(self):
    """Return one Adam optimizer per ensemble member, in ``self.net`` order.

    The optimizer at position i owns only member i's parameters, which is
    what lets training_step use ``optimizer_idx`` as a member index.
    """
    optimizers = []
    for member in self.net:
        optimizers.append(
            torch.optim.Adam(member.parameters(), lr=self.cfg.SOLVER.LR)
        )
    return optimizers
def forward(self, x, idx=None):
    """Run one ensemble member, or all of them.

    Args:
        x: Input batch, passed unchanged to the member(s).
        idx: Index into ``self.net``. If ``None``, every member is run.

    Returns:
        The selected member's output, or, when ``idx`` is ``None``, all
        members' outputs stacked along a new trailing dimension.
    """
    # Call the modules instead of ``.forward`` so registered hooks still run.
    if idx is not None:
        return self.net[idx](x)
    return torch.stack([net(x) for net in self.net], dim=-1)
def training_step(self, batch, batch_idx, optimizer_idx):
    """Train the ensemble member paired with ``optimizer_idx`` on one batch.

    Lightning's automatic optimization calls this once per optimizer for each
    batch. Each call runs only its own member and caches that member's
    detached logits. The call for the last member logs loss and metrics of
    the ensemble mean, so every member does one forward pass per batch.

    Args:
        batch: Mapping with "image" and "label" entries.
        batch_idx: Index of the batch (unused).
        optimizer_idx: Index of the optimizer, and hence of the member, being
            trained in this call.

    Returns:
        The criterion value for member ``optimizer_idx``.
    """
    image, label = batch["image"], batch["label"]
    # Start a fresh cache with the first member so predictions left over from
    # an earlier, incomplete pass over the members never mix into this batch.
    if optimizer_idx == 0:
        self.cache_preds = []
    logits = self.forward(image, optimizer_idx)
    self.cache_preds.append(logits.detach())
    loss = self.criterion(logits, label)
    if optimizer_idx == len(self.net) - 1:
        mean_logit = torch.stack(self.cache_preds, dim=-1).mean(dim=-1)
        self.log_loss(mean_logit, label, 'train')
        self.log_metrics(mean_logit, label, 'train')
        self.cache_preds = []
    return loss
def validation_step(self, batch, batch_idx):
    """Score the averaged ensemble prediction on one validation batch.

    ``forward`` without an index stacks all members' outputs on the last
    dimension; that dimension is averaged, then ``log_loss`` and
    ``log_metrics`` are called with the 'val' tag. Returns the entry of the
    metrics result selected by ``cfg.CKPT.MONITOR``.
    """
    image = batch["image"]
    label = batch["label"]
    stacked = self.forward(image)
    ensemble_logit = stacked.mean(dim=-1)
    self.log_loss(ensemble_logit, label, 'val')
    results = self.log_metrics(ensemble_logit, label, 'val')
    return results[self.cfg.CKPT.MONITOR]
Beta Was this translation helpful? Give feedback.
I see two potential options:
1. Cache the forward output for a given batch index. See the automatic optimization flow: https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html#automatic-optimization
2. Use manual optimization: https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html#manual-optimization