
Commit 4ef1335: update network

1 parent 363cca9

16 files changed: +350, -253 lines

pymic/loss/ce.py

Lines changed: 41 additions & 0 deletions
@@ -38,3 +38,44 @@ def forward(self, loss_input_dict):
         else:
             ce = torch.mean(ce)
         return ce
+
+class GeneralizedCrossEntropyLoss(nn.Module):
+    """
+    Generalized cross entropy loss to deal with noisy labels.
+    Z. Zhang et al. Generalized Cross Entropy Loss for Training Deep Neural Networks
+    with Noisy Labels, NeurIPS 2018.
+    """
+    def __init__(self, params):
+        super(GeneralizedCrossEntropyLoss, self).__init__()
+        self.enable_pix_weight = params['GeneralizedCrossEntropyLoss_Enable_Pixel_Weight'.lower()]
+        self.enable_cls_weight = params['GeneralizedCrossEntropyLoss_Enable_Class_Weight'.lower()]
+        self.q = params['GeneralizedCrossEntropyLoss_q'.lower()]
+
+    def forward(self, loss_input_dict):
+        predict = loss_input_dict['prediction']
+        soft_y  = loss_input_dict['ground_truth']
+        pix_w   = loss_input_dict['pixel_weight']
+        cls_w   = loss_input_dict['class_weight']
+        softmax = loss_input_dict['softmax']
+
+        if(softmax):
+            predict = nn.Softmax(dim = 1)(predict)
+        predict = reshape_tensor_to_2D(predict)
+        soft_y  = reshape_tensor_to_2D(soft_y)
+        gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y
+
+        if(self.enable_cls_weight):
+            if(cls_w is None):
+                raise ValueError("Class weight is enabled but not defined")
+            gce = torch.sum(gce * cls_w, dim = 1)
+        else:
+            gce = torch.sum(gce, dim = 1)
+
+        if(self.enable_pix_weight):
+            if(pix_w is None):
+                raise ValueError("Pixel weight is enabled but not defined")
+            pix_w = reshape_tensor_to_2D(pix_w)
+            gce = torch.sum(gce * pix_w) / torch.sum(pix_w)
+        else:
+            gce = torch.mean(gce)
+        return gce
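A minimal usage sketch for the new loss (not part of this commit). It assumes the convention visible above: configuration keys are the lowercased, class-prefixed names, and forward() takes a dict containing the prediction, ground_truth, pixel_weight, class_weight and softmax entries. The tensor shapes and the q value below are illustrative only.

    import torch
    from pymic.loss.ce import GeneralizedCrossEntropyLoss

    # illustrative configuration; q = 0.5 is just an example value
    params = {'generalizedcrossentropyloss_enable_pixel_weight': False,
              'generalizedcrossentropyloss_enable_class_weight': False,
              'generalizedcrossentropyloss_q': 0.5}
    loss_fn = GeneralizedCrossEntropyLoss(params)

    pred   = torch.randn(2, 3, 16, 16)                           # N x C x H x W logits (assumed shape)
    labels = torch.randint(0, 3, (2, 1, 16, 16))
    soft_y = torch.zeros_like(pred).scatter_(1, labels, 1.0)     # one-hot ground truth
    loss = loss_fn({'prediction': pred, 'ground_truth': soft_y,
                    'pixel_weight': None, 'class_weight': None, 'softmax': True})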

pymic/loss/dice.py

Lines changed: 68 additions & 2 deletions
@@ -3,6 +3,7 @@

 import torch
 import torch.nn as nn
+from pymic.loss.ce import CrossEntropyLoss
 from pymic.loss.util import reshape_tensor_to_2D, get_classwise_dice

 class DiceLoss(nn.Module):

@@ -38,14 +39,33 @@ def forward(self, loss_input_dict):
         dice_loss = 1.0 - avg_dice
         return dice_loss

+class DiceWithCrossEntropyLoss(nn.Module):
+    def __init__(self, params):
+        super(DiceWithCrossEntropyLoss, self).__init__()
+        self.enable_pix_weight = params['DiceWithCrossEntropyLoss_Enable_Pixel_Weight'.lower()]
+        self.enable_cls_weight = params['DiceWithCrossEntropyLoss_Enable_Class_Weight'.lower()]
+        self.ce_weight = params['DiceWithCrossEntropyLoss_CE_Weight'.lower()]
+        dice_params = {'DiceLoss_Enable_Pixel_Weight'.lower(): self.enable_pix_weight,
+                'DiceLoss_Enable_Class_Weight'.lower(): self.enable_cls_weight}
+        ce_params = {'CrossEntropyLoss_Enable_Pixel_Weight'.lower(): self.enable_pix_weight,
+                'CrossEntropyLoss_Enable_Class_Weight'.lower(): self.enable_cls_weight}
+        self.dice_loss = DiceLoss(dice_params)
+        self.ce_loss = CrossEntropyLoss(ce_params)
+
+    def forward(self, loss_input_dict):
+        loss1 = self.dice_loss(loss_input_dict)
+        loss2 = self.ce_loss(loss_input_dict)
+        loss = loss1 + self.ce_weight * loss2
+        return loss
+
 class MultiScaleDiceLoss(nn.Module):
     def __init__(self, params):
         super(MultiScaleDiceLoss, self).__init__()
         self.enable_pix_weight = params['MultiScaleDiceLoss_Enable_Pixel_Weight'.lower()]
         self.enable_cls_weight = params['MultiScaleDiceLoss_Enable_Class_Weight'.lower()]
         self.multi_scale_weight = params['MultiScaleDiceLoss_Scale_Weight'.lower()]
         dice_params = {'DiceLoss_Enable_Pixel_Weight'.lower(): self.enable_pix_weight,
-                       'DiceLoss_Enable_Class_Weight'.lower(): self.enable_cls_weight}
+                'DiceLoss_Enable_Class_Weight'.lower(): self.enable_cls_weight}
         self.base_loss = DiceLoss(dice_params)

     def forward(self, loss_input_dict):

@@ -80,4 +100,50 @@ def forward(self, loss_input_dict):
             loss = loss/weight
         else:
             loss = self.base_loss(loss_input_dict)
-        return loss
+        return loss
+
+class NoiseRobustDiceLoss(nn.Module):
+    """
+    Noise-robust Dice loss according to the following paper.
+    G. Wang et al. A Noise-Robust Framework for Automatic Segmentation of COVID-19
+    Pneumonia Lesions From CT Images, IEEE TMI, 2020.
+    """
+    def __init__(self, params):
+        super(NoiseRobustDiceLoss, self).__init__()
+        self.enable_pix_weight = params['NoiseRobustDiceLoss_Enable_Pixel_Weight'.lower()]
+        self.enable_cls_weight = params['NoiseRobustDiceLoss_Enable_Class_Weight'.lower()]
+        self.gamma = params['NoiseRobustDiceLoss_gamma'.lower()]
+
+    def forward(self, loss_input_dict):
+        predict = loss_input_dict['prediction']
+        soft_y  = loss_input_dict['ground_truth']
+        pix_w   = loss_input_dict['pixel_weight']
+        cls_w   = loss_input_dict['class_weight']
+        softmax = loss_input_dict['softmax']
+
+        if(softmax):
+            predict = nn.Softmax(dim = 1)(predict)
+        predict = reshape_tensor_to_2D(predict)
+        soft_y  = reshape_tensor_to_2D(soft_y)
+
+        numerator   = torch.abs(predict - soft_y)
+        numerator   = torch.pow(numerator, self.gamma)
+        denominator = predict + soft_y
+        if(self.enable_pix_weight):
+            if(pix_w is None):
+                raise ValueError("Pixel weight is enabled but not defined")
+            pix_w = reshape_tensor_to_2D(pix_w)
+            numerator   = numerator * pix_w
+            denominator = denominator * pix_w
+        numer_sum = torch.sum(numerator,   dim = 0)
+        denom_sum = torch.sum(denominator, dim = 0)
+        loss_vector = numer_sum / (denom_sum + 1e-5)
+
+        if(self.enable_cls_weight):
+            if(cls_w is None):
+                raise ValueError("Class weight is enabled but not defined")
+            weighted_dice = loss_vector * cls_w
+            loss = weighted_dice.sum() / cls_w.sum()
+        else:
+            loss = torch.mean(loss_vector)
+        return loss
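Similarly, a usage sketch for NoiseRobustDiceLoss (not part of this commit). The parameter keys follow the lowercased convention used in __init__ above; the gamma value and tensor shapes are only illustrative.

    import torch
    from pymic.loss.dice import NoiseRobustDiceLoss

    params = {'noiserobustdiceloss_enable_pixel_weight': False,
              'noiserobustdiceloss_enable_class_weight': False,
              'noiserobustdiceloss_gamma': 1.5}         # illustrative gamma
    loss_fn = NoiseRobustDiceLoss(params)

    pred   = torch.randn(2, 2, 32, 32)                  # N x C x H x W logits (assumed shape)
    soft_y = torch.zeros_like(pred)
    soft_y[:, 0] = 1.0                                  # dummy one-hot labels, all background
    loss = loss_fn({'prediction': pred, 'ground_truth': soft_y,
                    'pixel_weight': None, 'class_weight': None, 'softmax': True})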

pymic/loss/loss_factory.py

Lines changed: 8 additions & 2 deletions
@@ -1,11 +1,17 @@
 # -*- coding: utf-8 -*-
 from __future__ import print_function, division
-from pymic.loss.ce import CrossEntropyLoss
+from pymic.loss.ce import CrossEntropyLoss, GeneralizedCrossEntropyLoss
 from pymic.loss.dice import DiceLoss, MultiScaleDiceLoss
+from pymic.loss.dice import DiceWithCrossEntropyLoss, NoiseRobustDiceLoss
+from pymic.loss.exp_log import ExpLogLoss

 loss_dict = {'CrossEntropyLoss': CrossEntropyLoss,
+    'GeneralizedCrossEntropyLoss': GeneralizedCrossEntropyLoss,
     'DiceLoss': DiceLoss,
-    'MultiScaleDiceLoss': MultiScaleDiceLoss}
+    'MultiScaleDiceLoss': MultiScaleDiceLoss,
+    'DiceWithCrossEntropyLoss': DiceWithCrossEntropyLoss,
+    'NoiseRobustDiceLoss': NoiseRobustDiceLoss,
+    'ExpLogLoss': ExpLogLoss}

 def get_loss(params):
     loss_type = params['loss_type']
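For context, a sketch of how the factory would be driven from a configuration dict. Only the 'loss_type' lookup is visible in this hunk; the assumption that get_loss passes params on to the selected loss class mirrors get_network in pymic/net/net_factory.py below.

    from pymic.loss.loss_factory import get_loss

    params = {'loss_type': 'NoiseRobustDiceLoss',
              # remaining keys are the ones expected by NoiseRobustDiceLoss above
              'noiserobustdiceloss_enable_pixel_weight': False,
              'noiserobustdiceloss_enable_class_weight': False,
              'noiserobustdiceloss_gamma': 1.5}
    loss_fn = get_loss(params)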

pymic/net/net2d/cople_net.py

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
# -*- coding: utf-8 -*-
# Author: Guotai Wang
# Date: 12 June, 2020
# Implementation of COPLENet for COVID-19 pneumonia lesion segmentation from CT images.
# Reference:
#     G. Wang et al. A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions
#     from CT Images. IEEE Transactions on Medical Imaging, 2020. DOI:10.1109/TMI.2020.3000314.

from __future__ import print_function, division
import torch
import torch.nn as nn

class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size = 1):
        super(ConvLayer, self).__init__()
        padding = int((kernel_size - 1) / 2)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.conv(x)

class SEBlock(nn.Module):
    def __init__(self, in_channels, r):
        super(SEBlock, self).__init__()

        redu_chns = int(in_channels / r)
        self.se_layers = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, redu_chns, kernel_size=1, padding=0),
            nn.LeakyReLU(),
            nn.Conv2d(redu_chns, in_channels, kernel_size=1, padding=0),
            nn.ReLU())

    def forward(self, x):
        f = self.se_layers(x)
        return f*x + x

class ASPPBlock(nn.Module):
    def __init__(self, in_channels, out_channels_list, kernel_size_list, dilation_list):
        super(ASPPBlock, self).__init__()
        self.conv_num = len(out_channels_list)
        assert(self.conv_num == 4)
        assert(self.conv_num == len(kernel_size_list) and self.conv_num == len(dilation_list))
        pad0 = int((kernel_size_list[0] - 1) / 2 * dilation_list[0])
        pad1 = int((kernel_size_list[1] - 1) / 2 * dilation_list[1])
        pad2 = int((kernel_size_list[2] - 1) / 2 * dilation_list[2])
        pad3 = int((kernel_size_list[3] - 1) / 2 * dilation_list[3])
        self.conv_1 = nn.Conv2d(in_channels, out_channels_list[0], kernel_size = kernel_size_list[0],
                                dilation = dilation_list[0], padding = pad0)
        self.conv_2 = nn.Conv2d(in_channels, out_channels_list[1], kernel_size = kernel_size_list[1],
                                dilation = dilation_list[1], padding = pad1)
        self.conv_3 = nn.Conv2d(in_channels, out_channels_list[2], kernel_size = kernel_size_list[2],
                                dilation = dilation_list[2], padding = pad2)
        self.conv_4 = nn.Conv2d(in_channels, out_channels_list[3], kernel_size = kernel_size_list[3],
                                dilation = dilation_list[3], padding = pad3)

        out_channels = out_channels_list[0] + out_channels_list[1] + out_channels_list[2] + out_channels_list[3]
        self.conv_1x1 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU())

    def forward(self, x):
        x1 = self.conv_1(x)
        x2 = self.conv_2(x)
        x3 = self.conv_3(x)
        x4 = self.conv_4(x)

        y = torch.cat([x1, x2, x3, x4], dim=1)
        y = self.conv_1x1(y)
        return y

class ConvBNActBlock(nn.Module):
    """Two convolution layers with batch norm, leaky relu, dropout and SE block"""
    def __init__(self, in_channels, out_channels, dropout_p):
        super(ConvBNActBlock, self).__init__()
        self.conv_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            nn.Dropout(dropout_p),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            SEBlock(out_channels, 2)
        )

    def forward(self, x):
        return self.conv_conv(x)

class DownBlock(nn.Module):
    """Downsampling by a concatenation of max-pool and avg-pool, followed by ConvBNActBlock"""
    def __init__(self, in_channels, out_channels, dropout_p):
        super(DownBlock, self).__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.avgpool = nn.AvgPool2d(2)
        self.conv = ConvBNActBlock(2 * in_channels, out_channels, dropout_p)

    def forward(self, x):
        x_max = self.maxpool(x)
        x_avg = self.avgpool(x)
        x_cat = torch.cat([x_max, x_avg], dim=1)
        y = self.conv(x_cat)
        return y + x_cat

class UpBlock(nn.Module):
    """Upsampling followed by ConvBNActBlock"""
    def __init__(self, in_channels1, in_channels2, out_channels,
                 bilinear=True, dropout_p = 0.5):
        super(UpBlock, self).__init__()
        self.bilinear = bilinear
        if bilinear:
            self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1)
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2)
        self.conv = ConvBNActBlock(in_channels2 * 2, out_channels, dropout_p)

    def forward(self, x1, x2):
        if self.bilinear:
            x1 = self.conv1x1(x1)
        x1 = self.up(x1)
        x_cat = torch.cat([x2, x1], dim=1)
        y = self.conv(x_cat)
        return y + x_cat

class COPLENet(nn.Module):
    def __init__(self, params):
        super(COPLENet, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.bilinear = self.params['bilinear']
        self.dropout = self.params['dropout']
        assert(len(self.ft_chns) == 5)

        f0_half = int(self.ft_chns[0] / 2)
        f1_half = int(self.ft_chns[1] / 2)
        f2_half = int(self.ft_chns[2] / 2)
        f3_half = int(self.ft_chns[3] / 2)
        self.in_conv = ConvBNActBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
        self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1])
        self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2])
        self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
        self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])

        self.bridge0 = ConvLayer(self.ft_chns[0], f0_half)
        self.bridge1 = ConvLayer(self.ft_chns[1], f1_half)
        self.bridge2 = ConvLayer(self.ft_chns[2], f2_half)
        self.bridge3 = ConvLayer(self.ft_chns[3], f3_half)

        self.up1 = UpBlock(self.ft_chns[4], f3_half, self.ft_chns[3], dropout_p = self.dropout[3])
        self.up2 = UpBlock(self.ft_chns[3], f2_half, self.ft_chns[2], dropout_p = self.dropout[2])
        self.up3 = UpBlock(self.ft_chns[2], f1_half, self.ft_chns[1], dropout_p = self.dropout[1])
        self.up4 = UpBlock(self.ft_chns[1], f0_half, self.ft_chns[0], dropout_p = self.dropout[0])

        f4 = self.ft_chns[4]
        aspp_chns = [int(f4 / 4), int(f4 / 4), int(f4 / 4), int(f4 / 4)]
        aspp_knls = [1, 3, 3, 3]
        aspp_dila = [1, 2, 4, 6]
        self.aspp = ASPPBlock(f4, aspp_chns, aspp_knls, aspp_dila)

        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
                                  kernel_size = 3, padding = 1)

    def forward(self, x):
        x_shape = list(x.shape)
        if(len(x_shape) == 5):
            [N, C, D, H, W] = x_shape
            new_shape = [N*D, C, H, W]
            x = torch.transpose(x, 1, 2)
            x = torch.reshape(x, new_shape)
        x0 = self.in_conv(x)
        x0b = self.bridge0(x0)
        x1 = self.down1(x0)
        x1b = self.bridge1(x1)
        x2 = self.down2(x1)
        x2b = self.bridge2(x2)
        x3 = self.down3(x2)
        x3b = self.bridge3(x3)
        x4 = self.down4(x3)
        x4 = self.aspp(x4)

        x = self.up1(x4, x3b)
        x = self.up2(x, x2b)
        x = self.up3(x, x1b)
        x = self.up4(x, x0b)
        output = self.out_conv(x)

        if(len(x_shape) == 5):
            new_shape = [N, D] + list(output.shape)[1:]
            output = torch.reshape(output, new_shape)
            output = torch.transpose(output, 1, 2)
        return output
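A quick construction sketch for COPLENet (not part of this commit). The channel and dropout settings below are illustrative assumptions: feature_chns must have length 5, and the doubling channel pattern keeps the residual additions in DownBlock and UpBlock shape-consistent. Spatial sizes divisible by 16 avoid mismatches between up-sampled and bridge features. The forward pass accepts a 4D [N, C, H, W] tensor, or a 5D [N, C, D, H, W] volume that it processes slice by slice.

    import torch
    from pymic.net.net2d.cople_net import COPLENet

    params = {'in_chns': 1,
              'feature_chns': [32, 64, 128, 256, 512],   # illustrative; length must be 5
              'dropout':      [0.0, 0.0, 0.3, 0.4, 0.5],
              'class_num': 2,
              'bilinear': True}
    net = COPLENet(params)

    x = torch.randn(2, 1, 96, 96)      # 96 is divisible by 2**4 for the four down-samplings
    y = net(x)
    print(y.shape)                     # expected torch.Size([2, 2, 96, 96])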
pymic/net/net2d/unet2d_scse.py

Lines changed: 1 addition & 2 deletions
@@ -117,8 +117,7 @@ def forward(self, x):
               'feature_chns':[2, 8, 32, 48, 64],
               'dropout': [0, 0, 0.3, 0.4, 0.5],
               'class_num': 2,
-              'bilinear': True,
-              'acti_func': 'relu'}
+              'bilinear': True}
     Net = UNet2D_ScSE(params)
     Net = Net.double()
pymic/net/net_factory.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division
from pymic.net.net2d.unet2d import UNet2D
from pymic.net.net2d.cople_net import COPLENet
from pymic.net.net2d.unet2d_scse import UNet2D_ScSE
from pymic.net.net3d.unet2d5 import UNet2D5
from pymic.net.net3d.unet3d import UNet3D

net_dict = {
    'UNet2D': UNet2D,
    'COPLENet': COPLENet,
    'UNet2D_ScSE': UNet2D_ScSE,
    'UNet2D5': UNet2D5,
    'UNet3D': UNet3D
    }

def get_network(params):
    net_type = params['net_type']
    if(net_type in net_dict):
        net_obj = net_dict[net_type](params)
    else:
        raise ValueError("Undefined network type {0:}".format(net_type))
    return net_obj
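Given net_dict above, a network can be created from the same style of configuration dict; the COPLENet settings here reuse the illustrative values from the sketch after cople_net.py.

    from pymic.net.net_factory import get_network

    params = {'net_type': 'COPLENet',
              'in_chns': 1,
              'feature_chns': [32, 64, 128, 256, 512],
              'dropout':      [0.0, 0.0, 0.3, 0.4, 0.5],
              'class_num': 2,
              'bilinear': True}
    net = get_network(params)    # raises ValueError for an undefined net_type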

pymic/net2d/__init__.py

Whitespace-only changes.
