Skip to content

Commit 17dde67

Browse files
authored
2509 update focalloss to use sigmoid (7/July) (#2513)
* update focalloss to use sigmoid Signed-off-by: Wenqi Li <[email protected]> * temp tests Signed-off-by: Wenqi Li <[email protected]> * fixes tests Signed-off-by: Wenqi Li <[email protected]>
1 parent 0711e04 commit 17dde67

File tree

3 files changed

+70
-32
lines changed

3 files changed

+70
-32
lines changed

monai/losses/focal_loss.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
class FocalLoss(_Loss):
2424
"""
25-
Reimplementation of the Focal Loss described in:
25+
Reimplementation of the Focal Loss (with a built-in sigmoid activation) described in:
2626
2727
- "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017
2828
- "AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy",
@@ -77,12 +77,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
7777
"""
7878
Args:
7979
input: the shape should be BNH[WD], where N is the number of classes.
80-
The input should be the original logits since it will be transferred by
81-
`F.log_softmax` in the forward function.
80+
The input should be the original logits since it will be transformed by
81+
a sigmoid in the forward function.
8282
target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
8383
8484
Raises:
85-
AssertionError: When input and target (after one hot transform if setted)
85+
ValueError: When input and target (after one hot transform if set)
8686
have different shapes.
8787
ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
8888
ValueError: When ``self.weight`` is a sequence and the length is not equal to the
@@ -107,7 +107,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
107107
input = input[:, 1:]
108108

109109
if target.shape != input.shape:
110-
raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
110+
raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
111111

112112
i = input
113113
t = target
@@ -117,10 +117,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
117117
i = i.reshape(b, n, -1)
118118
t = t.reshape(b, n, -1)
119119

120-
# Compute the log proba.
121-
logpt = F.log_softmax(i, dim=1)
122-
# Get the proba
123-
pt = torch.exp(logpt) # B,H*W or B,N,H*W
120+
# computing binary cross entropy with logits
121+
# see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
122+
max_val = (-i).clamp(min=0)
123+
ce = i - i * t + max_val + ((-max_val).exp() + (-i - max_val).exp()).log()
124124

125125
if self.weight is not None:
126126
class_weight: Optional[torch.Tensor] = None
@@ -142,11 +142,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
142142
at = class_weight[None, :, None] # N => 1,N,1
143143
at = at.expand((t.size(0), -1, t.size(2))) # 1,N,1 => B,N,H*W
144144
# Multiply the per-element cross entropy by the class weights.
145-
logpt = logpt * at
145+
ce = ce * at
146146

147147
# Compute the loss mini-batch.
148-
weight = torch.pow(-pt + 1.0, self.gamma)
149-
loss = torch.mean(-weight * t * logpt, dim=-1)
148+
# (1-p_t)^gamma * (-log(p_t)), i.e. the modulating factor times the cross entropy, with reduced chance of overflow
149+
p = F.logsigmoid(-i * (t * 2.0 - 1.0))
150+
loss = torch.mean((p * self.gamma).exp() * ce, dim=-1)
151+
150152
if self.reduction == LossReduction.SUM.value:
151153
return loss.sum()
152154
if self.reduction == LossReduction.NONE.value:

tests/test_focal_loss.py

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,32 @@
2222

2323
class TestFocalLoss(unittest.TestCase):
2424
def test_consistency_with_cross_entropy_2d(self):
25-
# For gamma=0 the focal loss reduces to the cross entropy loss
26-
focal_loss = FocalLoss(to_onehot_y=True, gamma=0.0, reduction="mean", weight=1.0)
27-
ce = nn.CrossEntropyLoss(reduction="mean")
25+
"""For gamma=0 the focal loss reduces to the cross entropy loss"""
26+
focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="mean", weight=1.0)
27+
ce = nn.BCEWithLogitsLoss(reduction="mean")
2828
max_error = 0
2929
class_num = 10
3030
batch_size = 128
3131
for _ in range(100):
3232
# Create a random tensor of shape (batch_size, class_num, 8, 4)
3333
x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True)
3434
# Create a random batch of classes
35-
l = torch.randint(low=0, high=class_num, size=(batch_size, 1, 8, 4))
35+
l = torch.randint(low=0, high=2, size=(batch_size, class_num, 8, 4)).float()
3636
if torch.cuda.is_available():
3737
x = x.cuda()
3838
l = l.cuda()
3939
output0 = focal_loss(x, l)
40-
output1 = ce(x, l[:, 0]) / class_num
40+
output1 = ce(x, l)
4141
a = float(output0.cpu().detach())
4242
b = float(output1.cpu().detach())
4343
if abs(a - b) > max_error:
4444
max_error = abs(a - b)
4545
self.assertAlmostEqual(max_error, 0.0, places=3)
4646

4747
def test_consistency_with_cross_entropy_2d_onehot_label(self):
48-
# For gamma=0 the focal loss reduces to the cross entropy loss
49-
focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="mean")
50-
ce = nn.CrossEntropyLoss(reduction="mean")
48+
"""For gamma=0 the focal loss reduces to the cross entropy loss"""
49+
focal_loss = FocalLoss(to_onehot_y=True, gamma=0.0, reduction="mean")
50+
ce = nn.BCEWithLogitsLoss(reduction="mean")
5151
max_error = 0
5252
class_num = 10
5353
batch_size = 128
@@ -59,18 +59,18 @@ def test_consistency_with_cross_entropy_2d_onehot_label(self):
5959
if torch.cuda.is_available():
6060
x = x.cuda()
6161
l = l.cuda()
62-
output0 = focal_loss(x, one_hot(l, num_classes=class_num))
63-
output1 = ce(x, l[:, 0]) / class_num
62+
output0 = focal_loss(x, l)
63+
output1 = ce(x, one_hot(l, num_classes=class_num))
6464
a = float(output0.cpu().detach())
6565
b = float(output1.cpu().detach())
6666
if abs(a - b) > max_error:
6767
max_error = abs(a - b)
6868
self.assertAlmostEqual(max_error, 0.0, places=3)
6969

7070
def test_consistency_with_cross_entropy_classification(self):
71-
# for gamma=0 the focal loss reduces to the cross entropy loss
71+
"""for gamma=0 the focal loss reduces to the cross entropy loss"""
7272
focal_loss = FocalLoss(to_onehot_y=True, gamma=0.0, reduction="mean")
73-
ce = nn.CrossEntropyLoss(reduction="mean")
73+
ce = nn.BCEWithLogitsLoss(reduction="mean")
7474
max_error = 0
7575
class_num = 10
7676
batch_size = 128
@@ -84,19 +84,43 @@ def test_consistency_with_cross_entropy_classification(self):
8484
x = x.cuda()
8585
l = l.cuda()
8686
output0 = focal_loss(x, l)
87-
output1 = ce(x, l[:, 0]) / class_num
87+
output1 = ce(x, one_hot(l, num_classes=class_num))
8888
a = float(output0.cpu().detach())
8989
b = float(output1.cpu().detach())
9090
if abs(a - b) > max_error:
9191
max_error = abs(a - b)
9292
self.assertAlmostEqual(max_error, 0.0, places=3)
9393

94+
def test_consistency_with_cross_entropy_classification_01(self):
95+
# for gamma=0.1 the focal loss differs from the cross entropy loss
96+
focal_loss = FocalLoss(to_onehot_y=True, gamma=0.1, reduction="mean")
97+
ce = nn.BCEWithLogitsLoss(reduction="mean")
98+
max_error = 0
99+
class_num = 10
100+
batch_size = 128
101+
for _ in range(100):
102+
# Create a random scores tensor of shape (batch_size, class_num)
103+
x = torch.rand(batch_size, class_num, requires_grad=True)
104+
# Create a random batch of classes
105+
l = torch.randint(low=0, high=class_num, size=(batch_size, 1))
106+
l = l.long()
107+
if torch.cuda.is_available():
108+
x = x.cuda()
109+
l = l.cuda()
110+
output0 = focal_loss(x, l)
111+
output1 = ce(x, one_hot(l, num_classes=class_num))
112+
a = float(output0.cpu().detach())
113+
b = float(output1.cpu().detach())
114+
if abs(a - b) > max_error:
115+
max_error = abs(a - b)
116+
self.assertNotAlmostEqual(max_error, 0.0, places=3)
117+
94118
def test_bin_seg_2d(self):
95119
# define 2d examples
96120
target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
97121
# add another dimension corresponding to the batch (batch size = 1 here)
98122
target = target.unsqueeze(0) # shape (1, H, W)
99-
pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()
123+
pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0
100124

101125
# initialize the focal loss
102126
loss = FocalLoss(to_onehot_y=True)
@@ -112,7 +136,7 @@ def test_empty_class_2d(self):
112136
target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
113137
# add another dimension corresponding to the batch (batch size = 1 here)
114138
target = target.unsqueeze(0) # shape (1, H, W)
115-
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()
139+
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
116140

117141
# initialize the focal loss
118142
loss = FocalLoss(to_onehot_y=True)
@@ -128,7 +152,7 @@ def test_multi_class_seg_2d(self):
128152
target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
129153
# add another dimension corresponding to the batch (batch size = 1 here)
130154
target = target.unsqueeze(0) # shape (1, H, W)
131-
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()
155+
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
132156
# initialize the focal loss
133157
loss = FocalLoss(to_onehot_y=True)
134158
loss_onehot = FocalLoss(to_onehot_y=False)
@@ -159,7 +183,7 @@ def test_bin_seg_3d(self):
159183
# add another dimension corresponding to the batch (batch size = 1 here)
160184
target = target.unsqueeze(0) # shape (1, H, W, D)
161185
target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3) # test one hot
162-
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float()
186+
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0
163187

164188
# initialize the focal loss
165189
loss = FocalLoss(to_onehot_y=True)
@@ -173,6 +197,19 @@ def test_bin_seg_3d(self):
173197
focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
174198
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
175199

200+
def test_foreground(self):
201+
background = torch.ones(1, 1, 5, 5)
202+
foreground = torch.zeros(1, 1, 5, 5)
203+
target = torch.cat((background, foreground), dim=1)
204+
input = torch.cat((background, foreground), dim=1)
205+
target[:, 0, 2, 2] = 0
206+
target[:, 1, 2, 2] = 1
207+
208+
fgbg = FocalLoss(to_onehot_y=False, include_background=True)(input, target)
209+
fg = FocalLoss(to_onehot_y=False, include_background=False)(input, target)
210+
self.assertAlmostEqual(float(fgbg.cpu()), 0.1116, places=3)
211+
self.assertAlmostEqual(float(fg.cpu()), 0.1733, places=3)
212+
176213
def test_ill_opts(self):
177214
chn_input = torch.ones((1, 2, 3))
178215
chn_target = torch.ones((1, 2, 3))
@@ -182,7 +219,7 @@ def test_ill_opts(self):
182219
def test_ill_shape(self):
183220
chn_input = torch.ones((1, 2, 3))
184221
chn_target = torch.ones((1, 3))
185-
with self.assertRaisesRegex(AssertionError, ""):
222+
with self.assertRaisesRegex(ValueError, ""):
186223
FocalLoss(reduction="mean")(chn_input, chn_target)
187224

188225
def test_ill_class_weight(self):

tests/test_masked_loss.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"to_onehot_y": True,
3333
"reduction": "sum",
3434
},
35-
[(12.105497, 18.805185), (10.636354, 6.3138)],
35+
[(14.538666, 20.191753), (13.17672, 8.251623)],
3636
],
3737
]
3838

@@ -50,7 +50,6 @@ def test_shape(self, input_param, expected_val):
5050
label = torch.randint(low=0, high=2, size=size)
5151
label = torch.argmax(label, dim=1, keepdim=True)
5252
pred = torch.randn(size)
53-
print(label[0, 0, 0])
5453
result = MaskedLoss(**input_param)(pred, label, None)
5554
out = result.detach().cpu().numpy()
5655
checked = np.allclose(out, expected_val[0][0]) or np.allclose(out, expected_val[0][1])

0 commit comments

Comments
 (0)