Skip to content

Commit 96a52d3

Browse files
committed
numeric stability for ce loss
1 parent a340e60 commit 96a52d3

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

pymic/loss/cls/ce.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def forward(self, loss_input_dict):
3838
predict = loss_input_dict['prediction']
3939
labels = loss_input_dict['ground_truth']
4040
# for numeric stability
41-
predict = nn.Sigmoid()(predict) * 0.999 + 0.0005
41+
predict = nn.Sigmoid()(predict) * 0.999 + 5e-4
4242
loss = - labels * torch.log(predict) - (1 - labels) * torch.log( 1 - predict)
4343
loss = loss.mean()
4444
return loss

pymic/loss/seg/ce.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ def forward(self, loss_input_dict):
2525
predict = reshape_tensor_to_2D(predict)
2626
soft_y = reshape_tensor_to_2D(soft_y)
2727

28+
# for numeric stability
29+
predict = predict * 0.999 + 5e-4
2830
ce = - soft_y* torch.log(predict)
2931
if(self.enable_cls_weight):
3032
if(cls_w is None):

0 commit comments

Comments
 (0)