|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +from __future__ import print_function, division |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn as nn |
| 6 | + |
| 7 | +class DiceLoss(nn.Module): |
| 8 | + def __init__(self, params = None): |
| 9 | + super(DiceLoss, self).__init__() |
| 10 | + if(params is None): |
| 11 | + self.softmax = True |
| 12 | + else: |
| 13 | + self.softmax = params.get('loss_softmax', True) |
| 14 | + |
| 15 | + def forward(self, loss_input_dict): |
| 16 | + predict = loss_input_dict['prediction'] |
| 17 | + soft_y = loss_input_dict['ground_truth'] |
| 18 | + |
| 19 | + if(isinstance(predict, (list, tuple))): |
| 20 | + predict = predict[0] |
| 21 | + if(self.softmax): |
| 22 | + predict = nn.Softmax(dim = 1)(predict) |
| 23 | + predict = reshape_tensor_to_2D(predict) |
| 24 | + soft_y = reshape_tensor_to_2D(soft_y) |
| 25 | + dice_score = get_classwise_dice(predict, soft_y) |
| 26 | + dice_loss = 1.0 - dice_score.mean() |
| 27 | + return dice_loss |
| 28 | + |
| 29 | +class MumfordShahLoss(nn.Module): |
| 30 | + """ |
| 31 | + Implementation of Mumford Shah Loss in this paper: |
| 32 | + Boah Kim and Jong Chul Ye, Mumford–Shah Loss Functional |
| 33 | + for Image Segmentation With Deep Learning. IEEE TIP, 2019. |
| 34 | + The oringial implementation is availabel at: |
| 35 | + https://github.com/jongcye/CNN_MumfordShah_Loss |
| 36 | + |
| 37 | + currently only 2D version is supported. |
| 38 | + """ |
| 39 | + def __init__(self, params = None): |
| 40 | + super(MumfordShahLoss, self).__init__() |
| 41 | + if(params is None): |
| 42 | + params = {} |
| 43 | + self.softmax = params.get('loss_softmax', True) |
| 44 | + self.penalty = params.get('MumfordShahLoss_penalty', "l1") |
| 45 | + self.grad_w = params.get('MumfordShahLoss_lambda', 1.0) |
| 46 | + |
| 47 | + def get_levelset_loss(self, output, target): |
| 48 | + """ |
| 49 | + output: softmax output of a network |
| 50 | + target: the input image |
| 51 | + """ |
| 52 | + outshape = output.shape |
| 53 | + tarshape = target.shape |
| 54 | + loss = 0.0 |
| 55 | + for ich in range(tarshape[1]): |
| 56 | + target_ = torch.unsqueeze(target[:,ich], 1) |
| 57 | + target_ = target_.expand(tarshape[0], outshape[1], tarshape[2], tarshape[3]) |
| 58 | + pcentroid = torch.sum(target_ * output, (2,3))/torch.sum(output, (2,3)) |
| 59 | + pcentroid = pcentroid.view(tarshape[0], outshape[1], 1, 1) |
| 60 | + plevel = target_ - pcentroid.expand(tarshape[0], outshape[1], tarshape[2], tarshape[3]) |
| 61 | + pLoss = plevel * plevel * output |
| 62 | + loss += torch.sum(pLoss) |
| 63 | + return loss |
| 64 | + |
| 65 | + def get_gradient_loss(self, pred, penalty = "l2"): |
| 66 | + dH = torch.abs(pred[:, :, 1:, :] - pred[:, :, :-1, :]) |
| 67 | + dW = torch.abs(pred[:, :, :, 1:] - pred[:, :, :, :-1]) |
| 68 | + if penalty == "l2": |
| 69 | + dH = dH * dH |
| 70 | + dW = dW * dW |
| 71 | + loss = torch.sum(dH) + torch.sum(dW) |
| 72 | + return loss |
| 73 | + |
| 74 | + def forward(self, loss_input_dict): |
| 75 | + predict = loss_input_dict['prediction'] |
| 76 | + image = loss_input_dict['image'] |
| 77 | + if(isinstance(predict, (list, tuple))): |
| 78 | + predict = predict[0] |
| 79 | + if(self.softmax): |
| 80 | + predict = nn.Softmax(dim = 1)(predict) |
| 81 | + |
| 82 | + pred_shape = list(predict.shape) |
| 83 | + if(len(pred_shape) == 5): |
| 84 | + [N, C, D, H, W] = pred_shape |
| 85 | + new_shape = [N*D, C, H, W] |
| 86 | + predict = torch.transpose(predict, 1, 2) |
| 87 | + predict = torch.reshape(predict, new_shape) |
| 88 | + [N, C, D, H, W] = list(image.shape) |
| 89 | + new_shape = [N*D, C, H, W] |
| 90 | + image = torch.transpose(image, 1, 2) |
| 91 | + image = torch.reshape(image, new_shape) |
| 92 | + loss0 = self.get_levelset_loss(predict, image) |
| 93 | + loss1 = self.get_gradient_loss(predict, self.penalty) |
| 94 | + loss = loss0 + self.grad_w * loss1 |
| 95 | + return loss/torch.numel(predict) |
0 commit comments