How to design a multi-task architecture? #7729
-
Hello, I'm designing a multi-task architecture for depth estimation and semantic segmentation. Because the two tasks are very similar, I reuse a single existing class for both and switch between them with a flag: class LitModel(pl.LightningModel):
def __init__(self, model: nn.Module, hparams):
    """Wrap ``model`` and configure losses, metrics and hooks for one task.

    Exactly one task is selected by the ``segmentation`` / ``depth`` flags.
    NOTE(review): ``segmentation``, ``depth`` and ``num_classes`` are not
    defined in this snippet -- presumably derived from ``hparams``; confirm.

    Args:
        model: the network that maps a batch of images to predictions.
        hparams: hyperparameters merged into ``self.hparams``.

    Raises:
        ValueError: if neither task flag is set (otherwise the task-specific
            attributes would be missing and later hooks would fail with an
            AttributeError).
    """
    super().__init__()
    # defaults shared by segmentation and depth
    # NOTE(review): Lightning's documented way to record hyperparameters is
    # ``self.save_hyperparameters(hparams)``; kept as-is to preserve behavior.
    self.hparams.update(hparams)
    self.model = model
    # SEGMENTATION STUFF
    if segmentation:
        self._batch_target_name = 'label'
        # VAL STUFF
        self._loss_func = CrossEntropyLoss
        self._val_loss_func = CrossEntropyLoss
        self._confusion_matrix = ConfusionMatrix(num_classes=num_classes, compute_on_step=False)
        self._best_metric = 0
        self._best_metric_name = 'mIoU'
        self._higher_is_better = True  # mIoU: larger is better
        # metrics whose accumulated state is cleared at each validation epoch start
        self._val_reset_metric_list = [self._val_loss_func, self._confusion_matrix]
        self._validation_step = self.seg_validation_step
        self._validation_epoch_end = self.seg_validation_epoch_end
    # DEPTH STUFF
    elif depth:
        self._batch_target_name = 'depth'
        # VAL STUFF
        self._loss_func = L1
        self._val_loss_func = MSE
        # RMSE is an error, so any first value must beat the starting point.
        self._best_metric = float('inf')
        self._best_metric_name = 'RMSE'
        self._higher_is_better = False  # RMSE: smaller is better
        self._val_reset_metric_list = [self._val_loss_func]
        self._validation_step = self.dpt_validation_step
        self._validation_epoch_end = self.dpt_validation_epoch_end
    else:
        raise ValueError('LitModel needs either the segmentation or the depth task enabled')
def forward(self, sample):
    """Run the wrapped network on ``sample`` and return its output unchanged."""
    network = self.model
    return network(sample)
def training_step(self, batch, batch_idx):
    """Compute and log the summed training loss for one batch.

    The input is read from ``batch['image']`` and the target from the
    task-specific key ``self._batch_target_name``. The target is wrapped in a
    one-element list before being passed to ``self._loss_func``, whose result
    is summed into a single loss. Returns ``{'loss': <summed loss>}``.
    """
    inputs = batch['image']
    targets = [batch[self._batch_target_name]]
    per_scale_losses = self._loss_func(self.model(inputs), targets)
    loss = sum(per_scale_losses)
    # Log only the epoch-level mean of the loss, not every step.
    self.log('train/loss', loss, on_step=False, on_epoch=True)
    return {'loss': loss}
def validation_step(self, batch, batch_idx):
    """Predict on one validation batch and hand off to the task-specific step.

    The task-specific handler (``self._validation_step``) receives the raw
    batch, its index, the ground truth and the prediction, and its return
    value is passed straight back to Lightning.
    """
    inputs = batch['image']
    ground_truth = batch[self._batch_target_name]
    output = self.model(inputs)
    handler = self._validation_step
    return handler(batch, batch_idx, ground_truth, output)
def seg_validation_step(self, batch, batch_idx, gt, prediction):
    """Segmentation validation step (implementation elided in this post).

    Intended to compute the segmentation validation losses from ``gt`` and
    ``prediction``. NOTE(review): presumably it also updates
    ``self._confusion_matrix``, which the epoch-end hook turns into mIoU;
    confirm.
    """
    # calculate segmentation validation losses
    ...
def dpt_validation_step(self, batch, batch_idx, gt, prediction):
    """Depth validation step (implementation elided in this post).

    Intended to compute the depth validation losses from ``gt`` and
    ``prediction``.
    """
def on_validation_epoch_start(self) -> None:
    """Clear accumulated state of every validation metric before a new epoch."""
    metrics_to_reset = self._val_reset_metric_list
    for metric_obj in metrics_to_reset:
        metric_obj.reset()
def validation_epoch_end(self, outputs) -> None:
    """Compute the epoch's validation metric and log it when it is a new best.

    The metric comes from the task-specific hook ``self._validation_epoch_end``.
    Whether "better" means larger (mIoU) or smaller (RMSE) is taken from
    ``self._higher_is_better`` when set, else inferred from the metric name.
    On improvement, logs ``eval/best_<name>`` and ``eval/best_epoch`` and
    stores the new best in ``self._best_metric``.

    NOTE(review): Lightning 2.0 removed ``validation_epoch_end`` in favour of
    ``on_validation_epoch_end``; rename if upgrading.
    """
    metric = self._validation_epoch_end(outputs)
    # RMSE is an error metric (lower is better); every other metric used here
    # (mIoU) improves upward.
    higher_is_better = getattr(self, '_higher_is_better', self._best_metric_name != 'RMSE')
    # Accept the first observed value unconditionally so the starting value of
    # ``_best_metric`` cannot block (or fake) improvements.
    no_best_yet = not getattr(self, '_best_metric_seen', False)
    if higher_is_better:
        improved = metric > self._best_metric
    else:
        improved = metric < self._best_metric
    if no_best_yet or improved:
        self.log(f'eval/best_{self._best_metric_name}', metric)
        self.log('eval/best_epoch', self.current_epoch)
        self._best_metric = metric
        self._best_metric_seen = True
def seg_validation_epoch_end(self, outputs) -> float:
    """Return this validation epoch's segmentation metric (mIoU).

    The mIoU computation from the accumulated confusion matrix is elided in
    this post; ``mIoU`` must be computed before the return. The returned value
    is compared against the running best by ``validation_epoch_end``.
    NOTE(review): may actually be a 0-dim tensor rather than a Python float.
    """
    # calculate confusion matrix and mIoU
    return mIoU
def dpt_validation_epoch_end(self, outputs) -> float:
    """Return this validation epoch's depth metric (RMSE).

    The RMSE computation is elided in this post; ``RMSE`` must be computed
    before the return. Lower values are better.
    NOTE(review): may actually be a 0-dim tensor rather than a Python float.
    """
    # calculate RMSE
    return RMSE
def configure_optimizers(self):... To me, this seems like really bad practice, and I hope it doesn't hurt too much to look at. I'm coming from TensorFlow, and I really, really like how clean PyTorch Lightning is; it feels great to work with. But what I wrote above just feels wrong, even though it works 😨
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 5 replies
-
First of all: I'm glad you are enjoying Lightning :) As for your code: it actually doesn't look bad to me. You're separating the functionality for the different use cases into different functions, which is perfectly fine. What you could do (if you really want to) is something like this: class BaseModel(LightningModule):
...  # Logic shared by every task model (e.g. holding the wrapped module) goes here.
class SegmentationModel(BaseModel):
    """Segmentation-only task model built on ``BaseModel``.

    Placeholder: the segmentation-specific logic (e.g. target key, losses,
    metrics) is added here.
    """
class DepthModel(BaseModel):
    """Depth-only task model built on ``BaseModel``.

    Placeholder: the depth-specific logic (e.g. target key, losses, metrics)
    is added here.
    """
class CombinedModel(LightningModule):
    """LightningModule that delegates to one task-specific model.

    ``__init__`` picks ``DepthModel`` or ``SegmentationModel`` once, so each
    task's logic lives in its own class. Only ``training_step`` is forwarded
    here; the other methods and hooks need the same one-line forwarding.
    """

    def __init__(self, model, hparams):
        # nn.Module.__init__ must run before a submodule is assigned to
        # ``self.model``; otherwise PyTorch raises AttributeError
        # ("cannot assign module before Module.__init__() call").
        super().__init__()
        # NOTE(review): ``depth`` is not defined in this snippet -- presumably
        # the task flag derived from ``hparams``; confirm.
        if depth:
            model = DepthModel(model, hparams)
        else:
            model = SegmentationModel(model, hparams)
        self.model = model

    def training_step(self, *args, **kwargs):
        # NOTE(review): the wrapped model is not the module the Trainer drives,
        # so ``self.log`` / ``self.current_epoch`` inside it may not behave as
        # in a top-level LightningModule -- verify logged values reach the logger.
        return self.model.training_step(*args, **kwargs)
# do the same for other methods and hooks That way your classes would be a bit more separated and self-contained. That being said, I still think your current approach is perfectly fine.
Beta Was this translation helpful? Give feedback.
First of all: I'm glad you are enjoying Lightning :)
As for your code: it actually doesn't look bad to me. You're separating the functionality for the different use cases into different functions, which is perfectly fine. What you could do (if you really want to) is something like this: