Skip to content

Commit 363cca9

Browse files
committed
update loss function
1 parent fec928b commit 363cca9

File tree

11 files changed

+274
-436
lines changed

11 files changed

+274
-436
lines changed

examples/JSRT/config/evaluation.cfg

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@ ground_truth_label_convert_source = None
1111
ground_truth_label_convert_target = None
1212
segmentation_label_convert_source = None
1313
segmentation_label_convert_target = None
14+
15+
16+
17+

examples/JSRT/config/train_test.cfg

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,12 @@ bilinear = True
5151

5252
[training]
5353
# device name" cuda:n or cpu
54-
device_name = cuda:0
54+
device_name = cuda:1
5555

5656
batch_size = 4
57-
loss_function = dice_loss
58-
class_weight = [0.5, 1.0]
57+
loss_type = DiceLoss
58+
DiceLoss_enable_pixel_weight = False
59+
DiceLoss_enable_class_weight = False
5960

6061
# for optimizers
6162
optimizer = Adam

examples/fetal_hc/config/train_test.cfg

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,9 @@ bilinear = True
6262
device_name = cuda:0
6363

6464
batch_size = 4
65-
loss_function = dice_loss
66-
class_weight = [0.5, 1.0]
65+
loss_type = DiceLoss
66+
DiceLoss_enable_pixel_weight = False
67+
DiceLoss_enable_class_weight = False
6768

6869
# for optimizers
6970
optimizer = Adam

examples/prostate/config/train_test.cfg

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ trilinear = True
5959
device_name = cuda:0
6060

6161
batch_size = 4
62-
loss_function = dice_loss
63-
class_weight = [0.5, 1.0]
62+
loss_type = DiceLoss
63+
DiceLoss_enable_pixel_weight = False
64+
DiceLoss_enable_class_weight = False
6465

6566
# for optimizers
6667
optimizer = Adam

pymic/loss/__init__.py

Whitespace-only changes.

pymic/loss/ce.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function, division
3+
4+
import torch
5+
import torch.nn as nn
6+
from pymic.loss.util import reshape_tensor_to_2D
7+
8+
class CrossEntropyLoss(nn.Module):
9+
def __init__(self, params):
10+
super(CrossEntropyLoss, self).__init__()
11+
self.enable_pix_weight = params['CrossEntropyLoss_Enable_Pixel_Weight'.lower()]
12+
self.enable_cls_weight = params['CrossEntropyLoss_Enable_Class_Weight'.lower()]
13+
14+
def forward(self, loss_input_dict):
15+
predict = loss_input_dict['prediction']
16+
soft_y = loss_input_dict['ground_truth']
17+
pix_w = loss_input_dict['pixel_weight']
18+
cls_w = loss_input_dict['class_weight']
19+
softmax = loss_input_dict['softmax']
20+
21+
if(softmax):
22+
predict = nn.Softmax(dim = 1)(predict)
23+
predict = reshape_tensor_to_2D(predict)
24+
soft_y = reshape_tensor_to_2D(soft_y)
25+
26+
ce = - soft_y* torch.log(predict)
27+
if(self.enable_cls_weight):
28+
if(cls_w is None):
29+
raise ValueError("Class weight is enabled but not defined")
30+
ce = torch.sum(ce * cls_w, dim = 1)
31+
else:
32+
ce = torch.sum(ce, dim = 1)
33+
if(self.enable_pix_weight):
34+
if(pix_w is None):
35+
raise ValueError("Pixel weight is enabled but not defined")
36+
pix_w = reshape_tensor_to_2D(pix_w)
37+
ce = torch.sum(ce * pix_w) / torch.sum(pix_w)
38+
else:
39+
ce = torch.mean(ce)
40+
return ce

pymic/loss/dice.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function, division
3+
4+
import torch
5+
import torch.nn as nn
6+
from pymic.loss.util import reshape_tensor_to_2D, get_classwise_dice
7+
8+
class DiceLoss(nn.Module):
9+
def __init__(self, params):
10+
super(DiceLoss, self).__init__()
11+
self.enable_pix_weight = params['DiceLoss_Enable_Pixel_Weight'.lower()]
12+
self.enable_cls_weight = params['DiceLoss_Enable_Class_Weight'.lower()]
13+
14+
def forward(self, loss_input_dict):
15+
predict = loss_input_dict['prediction']
16+
soft_y = loss_input_dict['ground_truth']
17+
pix_w = loss_input_dict['pixel_weight']
18+
cls_w = loss_input_dict['class_weight']
19+
softmax = loss_input_dict['softmax']
20+
21+
if(softmax):
22+
predict = nn.Softmax(dim = 1)(predict)
23+
predict = reshape_tensor_to_2D(predict)
24+
soft_y = reshape_tensor_to_2D(soft_y)
25+
26+
if(self.enable_pix_weight):
27+
if(pix_w is None):
28+
raise ValueError("Pixel weight is enabled but not defined")
29+
pix_w = reshape_tensor_to_2D(pix_w)
30+
dice_score = get_classwise_dice(predict, soft_y, pix_w)
31+
if(self.enable_cls_weight):
32+
if(cls_w is None):
33+
raise ValueError("Class weight is enabled but not defined")
34+
weighted_dice = dice_score * cls_w
35+
avg_dice = weighted_dice.sum() / cls_w.sum()
36+
else:
37+
avg_dice = torch.mean(dice_score)
38+
dice_loss = 1.0 - avg_dice
39+
return dice_loss
40+
41+
class MultiScaleDiceLoss(nn.Module):
42+
def __init__(self, params):
43+
super(MultiScaleDiceLoss, self).__init__()
44+
self.enable_pix_weight = params['MultiScaleDiceLoss_Enable_Pixel_Weight'.lower()]
45+
self.enable_cls_weight = params['MultiScaleDiceLoss_Enable_Class_Weight'.lower()]
46+
self.multi_scale_weight = params['MultiScaleDiceLoss_Scale_Weight'.lower()]
47+
dice_params = {'DiceLoss_Enable_Pixel_Weight'.lower(): self.enable_pix_weight,
48+
'DiceLoss_Enable_Class_Weight'.lower(): self.enable_cls_weight}
49+
self.base_loss = DiceLoss(dice_params)
50+
51+
def forward(self, loss_input_dict):
52+
predict = loss_input_dict['prediction']
53+
soft_y = loss_input_dict['ground_truth']
54+
pix_w = loss_input_dict['pixel_weight']
55+
cls_w = loss_input_dict['class_weight']
56+
softmax = loss_input_dict['softmax']
57+
if(isinstance(predict, tuple) or isinstance(predict, list)):
58+
predict_num = len(predict)
59+
assert(predict_num == len(self.multi_scale_weight))
60+
loss = 0.0
61+
weight = 0.0
62+
interp_mode = 'trilinear' if(len(predict[0].shape) == 5) else 'bilinear'
63+
for i in range(predict_num):
64+
soft_y_temp = nn.functional.interpolate(soft_y,
65+
size = list(predict[i].shape)[2:], mode = interp_mode)
66+
if(pix_w is not None):
67+
pix_w_temp = nn.functional.interpolate(pix_w,
68+
size = list(predict[i].shape)[2:], mode = interp_mode)
69+
else:
70+
pix_w_temp = None
71+
temp_loss_dict = {}
72+
temp_loss_dict['prediction'] = predict[i]
73+
temp_loss_dict['ground_truth'] = soft_y_temp
74+
temp_loss_dict['pixel_weight'] = pix_w_temp
75+
temp_loss_dict['class_weight'] = cls_w
76+
temp_loss_dict['softmax'] = softmax
77+
temp_loss = self.base_loss(temp_loss_dict)
78+
loss = loss + temp_loss * self.multi_scale_weight[i]
79+
weight = weight + self.multi_scale_weight[i]
80+
loss = loss/weight
81+
else:
82+
loss = self.base_loss(loss_input_dict)
83+
return loss

pymic/loss/loss_factory.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function, division
3+
from pymic.loss.ce import CrossEntropyLoss
4+
from pymic.loss.dice import DiceLoss, MultiScaleDiceLoss
5+
6+
loss_dict = {'CrossEntropyLoss': CrossEntropyLoss,
7+
'DiceLoss': DiceLoss,
8+
'MultiScaleDiceLoss': MultiScaleDiceLoss}
9+
10+
def get_loss(params):
11+
loss_type = params['loss_type']
12+
if(loss_type in loss_dict):
13+
loss_obj = loss_dict[loss_type](params)
14+
else:
15+
raise ValueError("Undefined loss type {0:}".format(loss_type))
16+
return loss_obj

pymic/loss/util.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function, division
3+
4+
import torch
5+
import torch.nn as nn
6+
import numpy as np
7+
8+
def get_soft_label(input_tensor, num_class, data_type = 'float'):
9+
"""
10+
convert a label tensor to soft label
11+
input_tensor: tensor with shae [B, 1, D, H, W]
12+
output_tensor: shape [B, num_class, D, H, W]
13+
"""
14+
tensor_list = []
15+
for i in range(num_class):
16+
temp_prob = input_tensor == i*torch.ones_like(input_tensor)
17+
tensor_list.append(temp_prob)
18+
output_tensor = torch.cat(tensor_list, dim = 1)
19+
if(data_type == 'float'):
20+
output_tensor = output_tensor.float()
21+
elif(data_type == 'double'):
22+
output_tensor = output_tensor.double()
23+
else:
24+
raise ValueError("data type can only be float and double: {0:}".format(data_type))
25+
26+
return output_tensor
27+
28+
def reshape_tensor_to_2D(x):
29+
"""
30+
reshape input variables of shape [B, C, D, H, W] to [voxel_n, C]
31+
"""
32+
tensor_dim = len(x.size())
33+
num_class = list(x.size())[1]
34+
if(tensor_dim == 5):
35+
x_perm = x.permute(0, 2, 3, 4, 1)
36+
elif(tensor_dim == 4):
37+
x_perm = x.permute(0, 2, 3, 1)
38+
else:
39+
raise ValueError("{0:}D tensor not supported".format(tensor_dim))
40+
41+
y = torch.reshape(x_perm, (-1, num_class))
42+
return y
43+
44+
def reshape_prediction_and_ground_truth(predict, soft_y):
45+
"""
46+
reshape input variables of shape [B, C, D, H, W] to [voxel_n, C]
47+
"""
48+
tensor_dim = len(predict.size())
49+
num_class = list(predict.size())[1]
50+
if(tensor_dim == 5):
51+
soft_y = soft_y.permute(0, 2, 3, 4, 1)
52+
predict = predict.permute(0, 2, 3, 4, 1)
53+
elif(tensor_dim == 4):
54+
soft_y = soft_y.permute(0, 2, 3, 1)
55+
predict = predict.permute(0, 2, 3, 1)
56+
else:
57+
raise ValueError("{0:}D tensor not supported".format(tensor_dim))
58+
59+
predict = torch.reshape(predict, (-1, num_class))
60+
soft_y = torch.reshape(soft_y, (-1, num_class))
61+
62+
return predict, soft_y
63+
64+
def get_classwise_dice(predict, soft_y, pix_w = None):
65+
"""
66+
get dice scores for each class in predict (after softmax) and soft_y
67+
"""
68+
69+
if(pix_w is None):
70+
y_vol = torch.sum(soft_y, dim = 0)
71+
p_vol = torch.sum(predict, dim = 0)
72+
intersect = torch.sum(soft_y * predict, dim = 0)
73+
else:
74+
y_vol = torch.sum(soft_y * pix_w, dim = 0)
75+
p_vol = torch.sum(predict * pix_w, dim = 0)
76+
intersect = torch.sum(soft_y * predict * pix_w, dim = 0)
77+
dice_score = (2.0 * intersect + 1e-5)/ (y_vol + p_vol + 1e-5)
78+
return dice_score

0 commit comments

Comments
 (0)