-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathloss.py
More file actions
247 lines (199 loc) · 9.09 KB
/
loss.py
File metadata and controls
247 lines (199 loc) · 9.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import torch
import torch.nn.functional as F
from torch import nn as nn
from torch.autograd import Variable
from torch.nn import MSELoss, SmoothL1Loss, L1Loss
import numpy as np
def compute_per_channel_dice(input, target, epsilon=1e-5, ignore_index=None, weight=None):
    """Return the soft Dice coefficient for every channel.

    ``input`` is assumed to already be a normalized probability and must have
    exactly the same shape as ``target``.

    Args:
        input: predicted probabilities, channel axis at position 1.
        target: ground truth with the same shape as ``input``.
        epsilon: lower bound applied to the denominator to avoid division by zero.
        ignore_index: optional target value; positions where ``target`` equals it
            are zeroed in both ``input`` and ``target`` before the computation.
        weight: optional factor multiplied into the per-channel intersection.

    Returns:
        Tensor holding one Dice value per channel.
    """
    assert input.size() == target.size(), "'input' and 'target' must have the same shape"

    if ignore_index is not None:
        # 1 where the voxel counts, 0 where it carries the ignored label
        keep = target.clone().ne_(ignore_index)
        keep.requires_grad = False
        input, target = input * keep, target * keep

    # Bring both tensors to a (C, everything-else) layout in floating point
    probs = flatten(input).float()
    labels = flatten(target).float()

    overlap = (probs * labels).sum(-1)
    if weight is not None:
        overlap = weight * overlap
    total = (probs + labels).sum(-1)

    return 2. * overlap / total.clamp(min=epsilon)
class DiceLoss(nn.Module):
    """Computes Dice Loss, which is just 1 - DiceCoefficient (see ``compute_per_channel_dice``).

    Additionally allows per-class weights to be provided.

    Args:
        epsilon: lower bound for the Dice denominator.
        weight: optional per-class weight tensor, stored as the ``weight`` buffer.
        ignore_index: forwarded to ``compute_per_channel_dice``.
        sigmoid_normalization: if True the logits are normalized with Sigmoid,
            otherwise with Softmax over the channel axis (dim=1).
        skip_last_target: if True the last channel of the one-hot target is dropped.
    """

    def __init__(self, epsilon=1e-5, weight=None, ignore_index=None, sigmoid_normalization=True,
                 skip_last_target=False):
        super(DiceLoss, self).__init__()
        self.epsilon = epsilon
        self.register_buffer('weight', weight)
        self.ignore_index = ignore_index
        # The output from the network during training is assumed to be un-normalized probabilities and we would
        # like to normalize the logits. Since Dice (or soft Dice in this case) is usually used for binary data,
        # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems.
        # However if one would like to apply Softmax in order to get the proper probability distribution from the
        # output, just specify sigmoid_normalization=False.
        if sigmoid_normalization:
            self.normalization = nn.Sigmoid()
        else:
            self.normalization = nn.Softmax(dim=1)
        # if True skip the last channel in the target
        self.skip_last_target = skip_last_target

    def forward(self, input, target):
        """Return the mean over channels of ``1 - Dice``.

        Args:
            input: raw network output (logits) with the channel axis at position 1.
            target: integer class indices; expected to have the shape of ``input``
                with the channel axis removed.

        Returns:
            Scalar tensor with the averaged Dice loss.
        """
        # Build the one-hot target on the same device as the logits so the loss works on
        # CPU as well as on any GPU (the previous torch.cuda.FloatTensor was CUDA-only).
        target_one_hot = torch.zeros(input.size(), dtype=torch.float32, device=input.device)
        # unsqueeze(1) inserts the channel axis for any number of spatial dims;
        # scatter_ requires an int64 index tensor.
        target_one_hot.scatter_(1, target.unsqueeze(1).long(), 1)
        target = target_one_hot

        # get probabilities from logits
        input = self.normalization(input)

        # The weight buffer must never receive gradients
        weight = self.weight.detach() if self.weight is not None else None

        if self.skip_last_target:
            # NOTE(review): this makes the target one channel shorter than `input`, which
            # the shape check in compute_per_channel_dice will reject - confirm intended use.
            target = target[:, :-1, ...]

        per_channel_dice = compute_per_channel_dice(input, target, epsilon=self.epsilon,
                                                    ignore_index=self.ignore_index, weight=weight)
        # Average the Dice score across all channels/classes
        return torch.mean(1. - per_channel_dice)
class WeightedCrossEntropyLoss(nn.Module):
    """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf

    Class weights are derived from the softmax of the current prediction and are
    optionally multiplied by a user supplied per-class ``weight``.

    Args:
        weight: optional per-class weight tensor, stored as the ``weight`` buffer.
        ignore_index: target value passed to ``F.cross_entropy`` as ``ignore_index``.
    """

    def __init__(self, weight=None, ignore_index=-1):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.register_buffer('weight', weight)
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """Return the cross entropy of ``input`` vs ``target`` using the dynamic class weights."""
        class_weights = self._class_weights(input)
        if self.weight is not None:
            # The static weights are constants; keep them out of the autograd graph
            class_weights = class_weights * self.weight.detach()
        return F.cross_entropy(input, target, weight=class_weights, ignore_index=self.ignore_index)

    @staticmethod
    def _class_weights(input):
        """Return one weight per class: sum(1 - p) / sum(p) over all non-channel positions."""
        # Normalize over the channel axis explicitly. Relying on the implicit dim is
        # deprecated and selects dim 0 for 3-D inputs, which would normalize over the batch.
        probs = F.softmax(input, dim=1)
        flattened = flatten(probs)
        nominator = (1. - flattened).sum(-1)
        denominator = flattened.sum(-1)
        # The weights are treated as constants: no gradient flows through them
        return (nominator / denominator).detach()
class GeneralizedDiceLoss(nn.Module):
    """Generalized Dice Loss (GDL), see https://arxiv.org/pdf/1707.03237.pdf

    Each class contribution is scaled by the inverse of its squared target sum.

    Args:
        epsilon: lower bound used when clamping denominators.
        weight: optional per-class weight tensor, stored as the ``weight`` buffer.
        ignore_index: optional target value whose positions are masked out.
        sigmoid_normalization: Sigmoid if True, otherwise Softmax over dim=1.
    """

    def __init__(self, epsilon=1e-5, weight=None, ignore_index=None, sigmoid_normalization=True):
        super(GeneralizedDiceLoss, self).__init__()
        self.epsilon = epsilon
        self.register_buffer('weight', weight)
        self.ignore_index = ignore_index
        self.normalization = nn.Sigmoid() if sigmoid_normalization else nn.Softmax(dim=1)

    def forward(self, input, target):
        """Return ``1 - generalized Dice`` for logits ``input`` and a same-shaped ``target``."""
        # logits -> probabilities
        probs = self.normalization(input)
        assert probs.size() == target.size(), "'input' and 'target' must have the same shape"

        if self.ignore_index is not None:
            # zero out every position that carries the ignored label
            keep = target.clone().ne_(self.ignore_index)
            keep.requires_grad = False
            probs = probs * keep
            target = target * keep

        probs = flatten(probs)
        labels = flatten(target).float()

        # inverse squared label volume per class, treated as a constant
        label_volume = labels.sum(-1)
        class_weights = Variable(1. / (label_volume * label_volume).clamp(min=self.epsilon), requires_grad=False)

        overlap = (probs * labels).sum(-1) * class_weights
        if self.weight is not None:
            overlap = Variable(self.weight, requires_grad=False) * overlap
        overlap = overlap.sum()

        total = ((probs + labels).sum(-1) * class_weights).sum()
        return 1. - 2. * overlap / total.clamp(min=self.epsilon)
def flatten(tensor):
    """Move the channel axis to the front and collapse all remaining axes into one.

    The shapes are transformed as follows:
        (N, C, D, H, W) -> (C, N * D * H * W)
    """
    num_channels = tensor.size(1)
    # Swapping the first two axes gives (C, N, D, H, W); the trailing axes keep their order
    channel_first = tensor.transpose(0, 1)
    # Collapse everything behind the channel axis: (C, N * D * H * W)
    return channel_first.reshape(num_channels, -1)
class CrossEntropyTopK(nn.Module):
    """Cross entropy averaged over only the hardest fraction of positions per sample.

    Args:
        weight: optional per-class weight passed to ``nn.CrossEntropyLoss``.
        top_k: fraction of the per-sample positions (those with the largest loss)
            that contribute to the result.
    """

    def __init__(self, weight=None, top_k=0.5):
        super(CrossEntropyTopK, self).__init__()
        # weight=None is the CrossEntropyLoss default, so one call covers both cases
        self.loss = nn.CrossEntropyLoss(weight=weight, reduction='none')
        self.top_k = top_k

    def forward(self, input, target):
        """Return the mean of the ``top_k`` fraction of largest per-position losses.

        Args:
            input: logits accepted by ``nn.CrossEntropyLoss``.
            target: class indices accepted by ``nn.CrossEntropyLoss``; its first
                axis is treated as the batch axis.

        Returns:
            Scalar tensor with the averaged top-k loss.
        """
        batch_size = target.size(0)
        per_sample = self.loss(input, target).view(batch_size, -1)
        num = per_sample.size(1)
        # Keep at least one element (int() may floor to 0, e.g. for 1-D targets, which
        # would make the mean NaN) and never request more elements than exist.
        k = min(max(int(self.top_k * num), 1), num)
        hardest, _ = torch.topk(per_sample, k)
        return torch.mean(hardest)
class FocalLoss(nn.Module):
    r"""
    This criterion is an implementation of Focal Loss, which is proposed in
    Focal Loss for Dense Object Detection.

        Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

    The losses are averaged across observations for each minibatch.

    Args:
        class_num(int): number of classes; used to build the default ``alpha`` of ones.
        alpha(Tensor): per-class scaling factor of shape (class_num,) or (class_num, 1).
        gamma(float, double): gamma > 0; reduces the relative loss for well-classified examples (p > .5),
            putting more focus on hard, misclassified examples.
        size_average(bool): By default, the losses are averaged over observations for each minibatch.
            However, if the field size_average is set to False, the losses are
            instead summed for each minibatch.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = torch.ones(class_num, 1)
        elif isinstance(alpha, torch.Tensor):
            self.alpha = alpha
        else:
            # Accept plain sequences as well as tensors
            self.alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` (N, C) for class indices ``targets``.

        Args:
            inputs: logits of shape (N, C).
            targets: integer class indices holding N elements.

        Returns:
            Scalar tensor: the mean (``size_average=True``) or the sum of the per-sample losses.
        """
        ids = targets.reshape(-1, 1)

        # log-softmax is numerically stable where log(softmax(x)) can underflow to log(0)
        log_p = F.log_softmax(inputs, dim=1).gather(1, ids)
        probs = log_p.exp()

        # alpha is a plain attribute (not a buffer), so follow the inputs' device here
        if self.alpha.device != inputs.device:
            self.alpha = self.alpha.to(inputs.device)
        # Force an (N, 1) column: a 1-D alpha would otherwise broadcast against the
        # (N, 1) probabilities into an (N, N) matrix and corrupt the reduction.
        alpha = self.alpha.reshape(-1)[ids.reshape(-1)].reshape(-1, 1)

        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            return batch_loss.mean()
        return batch_loss.sum()