1+ # -*- coding: utf-8 -*-
2+ # Author: Guotai Wang
3+ # Date: 12 June, 2020
4+ # Implementation of of COPLENet for COVID-19 pneumonia lesion segmentation from CT images.
5+ # Reference:
6+ # G. Wang et al. A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions
7+ # from CT Images. IEEE Transactions on Medical Imaging, 2020. DOI:10.1109/TMI.2020.3000314.
8+
9+ from __future__ import print_function , division
10+ import torch
11+ import torch .nn as nn
12+
13+ class ConvLayer (nn .Module ):
14+ def __init__ (self , in_channels , out_channels , kernel_size = 1 ):
15+ super (ConvLayer , self ).__init__ ()
16+ padding = int ((kernel_size - 1 ) / 2 )
17+ self .conv = nn .Sequential (
18+ nn .Conv2d (in_channels , out_channels , kernel_size = kernel_size , padding = padding ),
19+ nn .BatchNorm2d (out_channels ),
20+ nn .LeakyReLU ()
21+ )
22+
23+ def forward (self , x ):
24+ return self .conv (x )
25+
26+ class SEBlock (nn .Module ):
27+ def __init__ (self , in_channels , r ):
28+ super (SEBlock , self ).__init__ ()
29+
30+ redu_chns = int (in_channels / r )
31+ self .se_layers = nn .Sequential (
32+ nn .AdaptiveAvgPool2d (1 ),
33+ nn .Conv2d (in_channels , redu_chns , kernel_size = 1 , padding = 0 ),
34+ nn .LeakyReLU (),
35+ nn .Conv2d (redu_chns , in_channels , kernel_size = 1 , padding = 0 ),
36+ nn .ReLU ())
37+
38+ def forward (self , x ):
39+ f = self .se_layers (x )
40+ return f * x + x
41+
42+ class ASPPBlock (nn .Module ):
43+ def __init__ (self ,in_channels , out_channels_list , kernel_size_list , dilation_list ):
44+ super (ASPPBlock , self ).__init__ ()
45+ self .conv_num = len (out_channels_list )
46+ assert (self .conv_num == 4 )
47+ assert (self .conv_num == len (kernel_size_list ) and self .conv_num == len (dilation_list ))
48+ pad0 = int ((kernel_size_list [0 ] - 1 ) / 2 * dilation_list [0 ])
49+ pad1 = int ((kernel_size_list [1 ] - 1 ) / 2 * dilation_list [1 ])
50+ pad2 = int ((kernel_size_list [2 ] - 1 ) / 2 * dilation_list [2 ])
51+ pad3 = int ((kernel_size_list [3 ] - 1 ) / 2 * dilation_list [3 ])
52+ self .conv_1 = nn .Conv2d (in_channels , out_channels_list [0 ], kernel_size = kernel_size_list [0 ],
53+ dilation = dilation_list [0 ], padding = pad0 )
54+ self .conv_2 = nn .Conv2d (in_channels , out_channels_list [1 ], kernel_size = kernel_size_list [1 ],
55+ dilation = dilation_list [1 ], padding = pad1 )
56+ self .conv_3 = nn .Conv2d (in_channels , out_channels_list [2 ], kernel_size = kernel_size_list [2 ],
57+ dilation = dilation_list [2 ], padding = pad2 )
58+ self .conv_4 = nn .Conv2d (in_channels , out_channels_list [3 ], kernel_size = kernel_size_list [3 ],
59+ dilation = dilation_list [3 ], padding = pad3 )
60+
61+ out_channels = out_channels_list [0 ] + out_channels_list [1 ] + out_channels_list [2 ] + out_channels_list [3 ]
62+ self .conv_1x1 = nn .Sequential (
63+ nn .Conv2d (out_channels , out_channels , kernel_size = 1 , padding = 0 ),
64+ nn .BatchNorm2d (out_channels ),
65+ nn .LeakyReLU ())
66+
67+ def forward (self , x ):
68+ x1 = self .conv_1 (x )
69+ x2 = self .conv_2 (x )
70+ x3 = self .conv_3 (x )
71+ x4 = self .conv_4 (x )
72+
73+ y = torch .cat ([x1 , x2 , x3 , x4 ], dim = 1 )
74+ y = self .conv_1x1 (y )
75+ return y
76+
77+ class ConvBNActBlock (nn .Module ):
78+ """Two convolution layers with batch norm, leaky relu, dropout and SE block"""
79+ def __init__ (self ,in_channels , out_channels , dropout_p ):
80+ super (ConvBNActBlock , self ).__init__ ()
81+ self .conv_conv = nn .Sequential (
82+ nn .Conv2d (in_channels , out_channels , kernel_size = 3 , padding = 1 ),
83+ nn .BatchNorm2d (out_channels ),
84+ nn .LeakyReLU (),
85+ nn .Dropout (dropout_p ),
86+ nn .Conv2d (out_channels , out_channels , kernel_size = 3 , padding = 1 ),
87+ nn .BatchNorm2d (out_channels ),
88+ nn .LeakyReLU (),
89+ SEBlock (out_channels , 2 )
90+ )
91+
92+ def forward (self , x ):
93+ return self .conv_conv (x )
94+
95+ class DownBlock (nn .Module ):
96+ """Downsampling by a concantenation of max-pool and avg-pool, followed by ConvBNActBlock
97+ """
98+ def __init__ (self , in_channels , out_channels , dropout_p ):
99+ super (DownBlock , self ).__init__ ()
100+ self .maxpool = nn .MaxPool2d (2 )
101+ self .avgpool = nn .AvgPool2d (2 )
102+ self .conv = ConvBNActBlock (2 * in_channels , out_channels , dropout_p )
103+
104+ def forward (self , x ):
105+ x_max = self .maxpool (x )
106+ x_avg = self .avgpool (x )
107+ x_cat = torch .cat ([x_max , x_avg ], dim = 1 )
108+ y = self .conv (x_cat )
109+ return y + x_cat
110+
111+ class UpBlock (nn .Module ):
112+ """Upssampling followed by ConvBNActBlock"""
113+ def __init__ (self , in_channels1 , in_channels2 , out_channels ,
114+ bilinear = True , dropout_p = 0.5 ):
115+ super (UpBlock , self ).__init__ ()
116+ self .bilinear = bilinear
117+ if bilinear :
118+ self .conv1x1 = nn .Conv2d (in_channels1 , in_channels2 , kernel_size = 1 )
119+ self .up = nn .Upsample (scale_factor = 2 , mode = 'bilinear' , align_corners = True )
120+ else :
121+ self .up = nn .ConvTranspose2d (in_channels1 , in_channels2 , kernel_size = 2 , stride = 2 )
122+ self .conv = ConvBNActBlock (in_channels2 * 2 , out_channels , dropout_p )
123+
124+ def forward (self , x1 , x2 ):
125+ if self .bilinear :
126+ x1 = self .conv1x1 (x1 )
127+ x1 = self .up (x1 )
128+ x_cat = torch .cat ([x2 , x1 ], dim = 1 )
129+ y = self .conv (x_cat )
130+ return y + x_cat
131+
132+ class COPLENet (nn .Module ):
133+ def __init__ (self , params ):
134+ super (COPLENet , self ).__init__ ()
135+ self .params = params
136+ self .in_chns = self .params ['in_chns' ]
137+ self .ft_chns = self .params ['feature_chns' ]
138+ self .n_class = self .params ['class_num' ]
139+ self .bilinear = self .params ['bilinear' ]
140+ self .dropout = self .params ['dropout' ]
141+ assert (len (self .ft_chns ) == 5 )
142+
143+ f0_half = int (self .ft_chns [0 ] / 2 )
144+ f1_half = int (self .ft_chns [1 ] / 2 )
145+ f2_half = int (self .ft_chns [2 ] / 2 )
146+ f3_half = int (self .ft_chns [3 ] / 2 )
147+ self .in_conv = ConvBNActBlock (self .in_chns , self .ft_chns [0 ], self .dropout [0 ])
148+ self .down1 = DownBlock (self .ft_chns [0 ], self .ft_chns [1 ], self .dropout [1 ])
149+ self .down2 = DownBlock (self .ft_chns [1 ], self .ft_chns [2 ], self .dropout [2 ])
150+ self .down3 = DownBlock (self .ft_chns [2 ], self .ft_chns [3 ], self .dropout [3 ])
151+ self .down4 = DownBlock (self .ft_chns [3 ], self .ft_chns [4 ], self .dropout [4 ])
152+
153+ self .bridge0 = ConvLayer (self .ft_chns [0 ], f0_half )
154+ self .bridge1 = ConvLayer (self .ft_chns [1 ], f1_half )
155+ self .bridge2 = ConvLayer (self .ft_chns [2 ], f2_half )
156+ self .bridge3 = ConvLayer (self .ft_chns [3 ], f3_half )
157+
158+ self .up1 = UpBlock (self .ft_chns [4 ], f3_half , self .ft_chns [3 ], dropout_p = self .dropout [3 ])
159+ self .up2 = UpBlock (self .ft_chns [3 ], f2_half , self .ft_chns [2 ], dropout_p = self .dropout [2 ])
160+ self .up3 = UpBlock (self .ft_chns [2 ], f1_half , self .ft_chns [1 ], dropout_p = self .dropout [1 ])
161+ self .up4 = UpBlock (self .ft_chns [1 ], f0_half , self .ft_chns [0 ], dropout_p = self .dropout [0 ])
162+
163+ f4 = self .ft_chns [4 ]
164+ aspp_chns = [int (f4 / 4 ), int (f4 / 4 ), int (f4 / 4 ), int (f4 / 4 )]
165+ aspp_knls = [1 , 3 , 3 , 3 ]
166+ aspp_dila = [1 , 2 , 4 , 6 ]
167+ self .aspp = ASPPBlock (f4 , aspp_chns , aspp_knls , aspp_dila )
168+
169+
170+ self .out_conv = nn .Conv2d (self .ft_chns [0 ], self .n_class ,
171+ kernel_size = 3 , padding = 1 )
172+
173+ def forward (self , x ):
174+ x_shape = list (x .shape )
175+ if (len (x_shape ) == 5 ):
176+ [N , C , D , H , W ] = x_shape
177+ new_shape = [N * D , C , H , W ]
178+ x = torch .transpose (x , 1 , 2 )
179+ x = torch .reshape (x , new_shape )
180+ x0 = self .in_conv (x )
181+ x0b = self .bridge0 (x0 )
182+ x1 = self .down1 (x0 )
183+ x1b = self .bridge1 (x1 )
184+ x2 = self .down2 (x1 )
185+ x2b = self .bridge2 (x2 )
186+ x3 = self .down3 (x2 )
187+ x3b = self .bridge3 (x3 )
188+ x4 = self .down4 (x3 )
189+ x4 = self .aspp (x4 )
190+
191+ x = self .up1 (x4 , x3b )
192+ x = self .up2 (x , x2b )
193+ x = self .up3 (x , x1b )
194+ x = self .up4 (x , x0b )
195+ output = self .out_conv (x )
196+
197+ if (len (x_shape ) == 5 ):
198+ new_shape = [N , D ] + list (output .shape )[1 :]
199+ output = torch .reshape (output , new_shape )
200+ output = torch .transpose (output , 1 , 2 )
201+ return output
0 commit comments