Commit 50a82ef

update docs for 3D networks

1 parent 451a85c commit 50a82ef

13 files changed: +420 -228 lines

pymic/layer/convolution.py

Lines changed: 33 additions & 70 deletions
@@ -7,8 +7,22 @@
 class ConvolutionLayer(nn.Module):
     """
     A compose layer with the following components:
-    convolution -> (batch_norm / layer_norm / group_norm / instance_norm) -> activation -> (dropout)
-    batch norm and dropout are optional
+    convolution -> (batch_norm / layer_norm / group_norm / instance_norm) -> (activation) -> (dropout)
+    Batch norm and activation are optional.
+
+    :param in_channels: (int) The input channel number.
+    :param out_channels: (int) The output channel number.
+    :param kernel_size: The size of the convolution kernel. It can be either a single
+        int or a tuple of two or three ints.
+    :param dim: (int) The dimension of convolution (2 or 3).
+    :param stride: (int) The stride of convolution.
+    :param padding: (int) Padding size.
+    :param dilation: (int) Dilation rate.
+    :param conv_group: (int) The group number of convolution.
+    :param bias: (bool) Add a bias term for convolution or not.
+    :param norm_type: (str or None) Normalization type, can be `batch_norm` or `group_norm`.
+    :param norm_group: (int) The number of groups for group normalization.
+    :param acti_func: (str or None) Activation function.
     """
     def __init__(self, in_channels, out_channels, kernel_size, dim = 3,
         stride = 1, padding = 0, dilation = 1, conv_group = 1, bias = True,
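
As a quick sanity check of the documented parameters, here is a usage sketch. It is not part of the commit; it assumes the constructor keywords match the docstring above and that `acti_func` may also be given as a callable module, which is how the layer applies it in `forward`.

    import torch
    import torch.nn as nn
    from pymic.layer.convolution import ConvolutionLayer

    # 4-channel 3D input; padding = 1 with a 3x3x3 kernel keeps the spatial size
    conv = ConvolutionLayer(in_channels = 4, out_channels = 16, kernel_size = 3,
        dim = 3, padding = 1, norm_type = 'batch_norm', acti_func = nn.ReLU())
    x = torch.rand(2, 4, 32, 48, 48)   # [batch, channel, D, H, W]
    y = conv(x)
    print(y.shape)                     # expected: torch.Size([2, 16, 32, 48, 48])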
@@ -50,9 +64,23 @@ def forward(self, x):

 class DepthSeperableConvolutionLayer(nn.Module):
     """
-    A compose layer with the following components:
-    convolution -> (batch_norm) -> activation -> (dropout)
-    batch norm and dropout are optional
+    Depth separable convolution with the following components:
+    1x1 conv -> group conv -> (batch_norm / layer_norm / group_norm / instance_norm) -> (activation) -> (dropout)
+    Batch norm and activation are optional.
+
+    :param in_channels: (int) The input channel number.
+    :param out_channels: (int) The output channel number.
+    :param kernel_size: The size of the convolution kernel. It can be either a single
+        int or a tuple of two or three ints.
+    :param dim: (int) The dimension of convolution (2 or 3).
+    :param stride: (int) The stride of convolution.
+    :param padding: (int) Padding size.
+    :param dilation: (int) Dilation rate.
+    :param conv_group: (int) The group number of convolution.
+    :param bias: (bool) Add a bias term for convolution or not.
+    :param norm_type: (str or None) Normalization type, can be `batch_norm` or `group_norm`.
+    :param norm_group: (int) The number of groups for group normalization.
+    :param acti_func: (str or None) Activation function.
     """
     def __init__(self, in_channels, out_channels, kernel_size, dim = 3,
         stride = 1, padding = 0, dilation = 1, conv_group = 1, bias = True,
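
The `1x1 conv -> group conv` factorization is what saves parameters here. A quick count with plain PyTorch modules (illustrative only, independent of this class):

    import torch.nn as nn

    cin, cout, k = 64, 64, 3
    full = nn.Conv3d(cin, cout, k, padding = 1, bias = False)
    sep  = nn.Sequential(            # 1x1 conv -> depthwise (grouped) conv
        nn.Conv3d(cin, cout, 1, bias = False),
        nn.Conv3d(cout, cout, k, padding = 1, groups = cout, bias = False))
    count = lambda m: sum(p.numel() for p in m.parameters())
    print(count(full))   # 64*64*27 = 110592
    print(count(sep))    # 64*64 + 64*27 = 5824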
@@ -97,68 +125,3 @@ def forward(self, x):
         f = self.acti_func(f)
         return f

-class ConvolutionSepAll3DLayer(nn.Module):
-    """
-    A compose layer with the following components:
-    convolution -> (batch_norm) -> activation -> (dropout)
-    batch norm and dropout are optional
-    """
-    def __init__(self, in_channels, out_channels, kernel_size, dim = 3,
-        stride = 1, padding = 0, dilation = 1, groups = 1, bias = True,
-        batch_norm = True, acti_func = None):
-        super(ConvolutionSepAll3DLayer, self).__init__()
-        self.n_in_chns  = in_channels
-        self.n_out_chns = out_channels
-        self.batch_norm = batch_norm
-        self.acti_func  = acti_func
-
-        assert(dim == 3)
-        chn = min(in_channels, out_channels)
-
-        self.conv_intra_plane1 = nn.Conv2d(chn, chn,
-            kernel_size, stride, padding, dilation, chn, bias)
-
-        self.conv_intra_plane2 = nn.Conv2d(chn, chn,
-            kernel_size, stride, padding, dilation, chn, bias)
-
-        self.conv_intra_plane3 = nn.Conv2d(chn, chn,
-            kernel_size, stride, padding, dilation, chn, bias)
-
-        self.conv_space_wise = nn.Conv2d(in_channels, out_channels,
-            1, stride, 0, dilation, 1, bias)
-
-        if(self.batch_norm):
-            self.bn = nn.BatchNorm3d(out_channels)
-
-    def forward(self, x):
-        in_shape = list(x.shape)
-        assert(len(in_shape) == 5)
-        [B, C, D, H, W] = in_shape
-        f0 = x.permute(0, 2, 1, 3, 4)               # [B, D, C, H, W]
-        f0 = f0.contiguous().view([B*D, C, H, W])
-
-        Cc = min(self.n_in_chns, self.n_out_chns)
-        Co = self.n_out_chns
-        if(self.n_in_chns > self.n_out_chns):
-            f0 = self.conv_space_wise(f0)           # [B*D, Cc, H, W]
-
-        f1 = self.conv_intra_plane1(f0)
-        f2 = f1.contiguous().view([B, D, Cc, H, W])
-        f2 = f2.permute(0, 3, 2, 1, 4)              # [B, H, Cc, D, W]
-        f2 = f2.contiguous().view([B*H, Cc, D, W])
-        f2 = self.conv_intra_plane2(f2)
-        f3 = f2.contiguous().view([B, H, Cc, D, W])
-        f3 = f3.permute(0, 4, 2, 3, 1)              # [B, W, Cc, D, H]
-        f3 = f3.contiguous().view([B*W, Cc, D, H])
-        f3 = self.conv_intra_plane3(f3)
-        if(self.n_in_chns <= self.n_out_chns):
-            f3 = self.conv_space_wise(f3)           # [B*W, Co, D, H]
-
-        f3 = f3.contiguous().view([B, W, Co, D, H])
-        f3 = f3.permute([0, 2, 3, 4, 1])            # [B, Co, D, H, W]
-
-        if(self.batch_norm):
-            f3 = self.bn(f3)
-        if(self.acti_func is not None):
-            f3 = self.acti_func(f3)
-        return f3

pymic/layer/deconvolution.py

Lines changed: 32 additions & 5 deletions
@@ -7,8 +7,21 @@
 class DeconvolutionLayer(nn.Module):
     """
     A compose layer with the following components:
-    deconvolution -> (batch_norm) -> activation -> (dropout)
-    batch norm and dropout are optional
+    deconvolution -> (batch_norm / layer_norm / group_norm / instance_norm) -> (activation) -> (dropout)
+    Batch norm and activation are optional.
+
+    :param in_channels: (int) The input channel number.
+    :param out_channels: (int) The output channel number.
+    :param kernel_size: The size of the convolution kernel. It can be either a single
+        int or a tuple of two or three ints.
+    :param dim: (int) The dimension of convolution (2 or 3).
+    :param stride: (int) The stride of convolution.
+    :param padding: (int) Padding size.
+    :param dilation: (int) Dilation rate.
+    :param groups: (int) The group number of convolution.
+    :param bias: (bool) Add a bias term for convolution or not.
+    :param batch_norm: (bool) Use batch norm or not.
+    :param acti_func: (str or None) Activation function.
     """
     def __init__(self, in_channels, out_channels, kernel_size,
         dim = 3, stride = 1, padding = 0, output_padding = 0,
@@ -44,9 +57,23 @@ def forward(self, x):

 class DepthSeperableDeconvolutionLayer(nn.Module):
     """
-    A compose layer with the following components:
-    convolution -> (batch_norm) -> activation -> (dropout)
-    batch norm and dropout are optional
+    Depth separable deconvolution with the following components:
+    1x1 conv -> deconv -> (batch_norm / layer_norm / group_norm / instance_norm) -> (activation) -> (dropout)
+    Batch norm and activation are optional.
+
+    :param in_channels: (int) The input channel number.
+    :param out_channels: (int) The output channel number.
+    :param kernel_size: The size of the convolution kernel. It can be either a single
+        int or a tuple of two or three ints.
+    :param dim: (int) The dimension of convolution (2 or 3).
+    :param stride: (int) The stride of convolution.
+    :param padding: (int) Padding size for the input.
+    :param output_padding: (int) Padding size for the output.
+    :param dilation: (int) Dilation rate.
+    :param groups: (int) The group number of convolution.
+    :param bias: (bool) Add a bias term for convolution or not.
+    :param batch_norm: (bool) Use batch norm or not.
+    :param acti_func: (str or None) Activation function.
     """
     def __init__(self, in_channels, out_channels, kernel_size,
         dim = 3, stride = 1, padding = 0, output_padding = 0,
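
The padding / output_padding distinction is easy to mix up. A small standalone sketch with PyTorch's nn.ConvTranspose3d (which these parameters mirror; that DeconvolutionLayer wraps it directly is an assumption here) shows how output_padding resolves the output-size ambiguity of a strided transposed convolution:

    import torch
    import torch.nn as nn

    # For each spatial dim:
    # out = (in - 1)*stride - 2*padding + dilation*(kernel - 1) + output_padding + 1
    up_a = nn.ConvTranspose3d(16, 8, kernel_size = 2, stride = 2)
    up_b = nn.ConvTranspose3d(16, 8, kernel_size = 3, stride = 2,
                              padding = 1, output_padding = 1)
    x = torch.rand(1, 16, 8, 8, 8)
    print(up_a(x).shape)   # torch.Size([1, 8, 16, 16, 16])
    print(up_b(x).shape)   # torch.Size([1, 8, 16, 16, 16]), thanks to output_padding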

pymic/layer/space2channel.py

Lines changed: 5 additions & 0 deletions
@@ -5,7 +5,10 @@
 import torch
 import torch.nn as nn
 import SimpleITK as sitk
+
 class SpaceToChannel3D(nn.Module):
+    """
+    Space to channel transform for 3D input."""
     def __init__(self):
         super(SpaceToChannel3D, self).__init__()

@@ -34,6 +37,8 @@ def forward(self, x):
         return x7

 class ChannelToSpace3D(nn.Module):
+    """
+    Channel to space transform for 3D input."""
     def __init__(self):
         super(ChannelToSpace3D, self).__init__()
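
Neither forward() body is shown in this diff, so as an assumption for orientation: a 3D space-to-channel transform conventionally folds each 2x2x2 spatial block into the channel axis, so [B, C, D, H, W] becomes [B, 8*C, D/2, H/2, W/2], and ChannelToSpace3D undoes it. A minimal sketch of that rearrangement:

    import torch

    def space_to_channel_3d(x):
        B, C, D, H, W = x.shape
        x = x.view(B, C, D // 2, 2, H // 2, 2, W // 2, 2)
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6)   # move the three factors of 2 next to C
        return x.reshape(B, C * 8, D // 2, H // 2, W // 2)

    x = torch.rand(1, 4, 16, 32, 32)
    print(space_to_channel_3d(x).shape)         # torch.Size([1, 32, 8, 16, 16])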

Lines changed: 116 additions & 46 deletions
@@ -1,6 +1,7 @@
 # pretrained models from pytorch: https://pytorch.org/vision/0.8/models.html
 from __future__ import print_function, division

+import itertools
 import torch
 import torch.nn as nn
 import torchvision.models as models
@@ -20,80 +21,149 @@
 # 'mnasnet': models.mnasnet1_0
 # }

-class ResNet18(nn.Module):
+class BuiltInNet(nn.Module):
+    """
+    Base class for built-in networks from PyTorch for classification.
+    Parameters should be set in the `params` dictionary that contains the
+    following fields:
+
+    :param input_chns: (int) Input channel number, default is 3.
+    :param pretrain: (bool) Use a pretrained model or not, default is True.
+    :param update_mode: (str) The strategy for updating layers: "all" means updating
+        all the layers, and "last" (by default) means updating the last layer,
+        as well as the first layer when `input_chns` is not 3.
+    """
     def __init__(self, params):
-        super(ResNet18, self).__init__()
-        self.params = params
-        cls_num = params['class_num']
-        in_chns = params.get('input_chns', 3)
+        super(BuiltInNet, self).__init__()
+        self.params  = params
+        self.in_chns = params.get('input_chns', 3)
         self.pretrain = params.get('pretrain', True)
-        self.update_layers = params.get('update_layers', 0)
+        self.update_mode = params.get('update_mode', "last")
+        self.net = None
+
+    def forward(self, x):
+        return self.net(x)
+
+    def get_parameters_to_update(self):
+        pass
+
+class ResNet18(BuiltInNet):
+    """
+    ResNet18 for classification.
+    Parameters should be set in the `params` dictionary that contains the
+    following fields:
+
+    :param input_chns: (int) Input channel number, default is 3.
+    :param pretrain: (bool) Use a pretrained model or not, default is True.
+    :param update_mode: (str) The strategy for updating layers: "all" means updating
+        all the layers, and "last" (by default) means updating the last layer,
+        as well as the first layer when `input_chns` is not 3.
+    """
+    def __init__(self, params):
+        super(ResNet18, self).__init__(params)
         self.net = models.resnet18(pretrained = self.pretrain)

         # replace the last layer
         num_ftrs = self.net.fc.in_features
-        self.net.fc = nn.Linear(num_ftrs, cls_num)
-
-    def forward(self, x):
-        return self.net(x)
+        self.net.fc = nn.Linear(num_ftrs, params['class_num'])
+
+        # replace the first layer when in_chns is not 3
+        if(self.in_chns != 3):
+            self.net.conv1 = nn.Conv2d(self.in_chns, 64, kernel_size=(7, 7),
+                stride=(2, 2), padding=(3, 3), bias=False)

     def get_parameters_to_update(self):
-        if(self.pretrain == False or self.update_layers == 0):
+        if(self.update_mode == "all"):
             return self.net.parameters()
-        elif(self.update_layers == -1):
-            return self.net.fc.parameters()
+        elif(self.update_mode == "last"):
+            params = self.net.fc.parameters()
+            if(self.in_chns != 3):
+                # combine the two parameter iterables into a single one
+                # see: https://dzone.com/articles/python-joining-multiple
+                params = itertools.chain(self.net.fc.parameters(),
+                                         self.net.conv1.parameters())
+            return params
         else:
-            raise(ValueError("update_layers can only be 0 (all layers) " +
-                "or -1 (the last layer)"))
+            raise(ValueError("update_mode can only be 'all' or 'last'."))

-class VGG16(nn.Module):
+class VGG16(BuiltInNet):
+    """
+    VGG16 for classification.
+    Parameters should be set in the `params` dictionary that contains the
+    following fields:
+
+    :param input_chns: (int) Input channel number, default is 3.
+    :param pretrain: (bool) Use a pretrained model or not, default is True.
+    :param update_mode: (str) The strategy for updating layers: "all" means updating
+        all the layers, and "last" (by default) means updating the last layer,
+        as well as the first layer when `input_chns` is not 3.
+    """
     def __init__(self, params):
-        super(VGG16, self).__init__()
-        self.params = params
-        cls_num = params['class_num']
-        in_chns = params.get('input_chns', 3)
-        self.pretrain = params.get('pretrain', True)
-        self.update_layers = params.get('update_layers', 0)
+        super(VGG16, self).__init__(params)
         self.net = models.vgg16(pretrained = self.pretrain)

         # replace the last layer
         num_ftrs = self.net.classifier[-1].in_features
-        self.net.classifier[-1] = nn.Linear(num_ftrs, cls_num)
-
-    def forward(self, x):
-        return self.net(x)
+        self.net.classifier[-1] = nn.Linear(num_ftrs, params['class_num'])
+
+        # replace the first layer when in_chns is not 3
+        if(self.in_chns != 3):
+            self.net.features[0] = nn.Conv2d(self.in_chns, 64, kernel_size=(3, 3),
+                stride=(1, 1), padding=(1, 1), bias=False)

     def get_parameters_to_update(self):
-        if(self.pretrain == False or self.update_layers == 0):
+        if(self.update_mode == "all"):
             return self.net.parameters()
-        elif(self.update_layers == -1):
-            return self.net.classifier[-1].parameters()
+        elif(self.update_mode == "last"):
+            params = self.net.classifier[-1].parameters()
+            if(self.in_chns != 3):
+                params = itertools.chain(self.net.classifier[-1].parameters(),
+                                         self.net.features[0].parameters())
+            return params
         else:
-            raise(ValueError("update_layers can only be 0 (all layers) " +
-                "or -1 (the last layer)"))
+            raise(ValueError("update_mode can only be 'all' or 'last'."))
+
class MobileNetV2(BuiltInNet):
129+
"""
130+
MobileNetV2 for classification.
131+
Parameters should be set in the `params` dictionary that contains the
132+
following fields:
74133
75-
class MobileNetV2(nn.Module):
134+
:param input_chns: (int) Input channel number, default is 3.
135+
:param pretrain: (bool) Using pretrained model or not, default is True.
136+
:param update_mode: (str) The strategy for updating layers: "`all`" means updating
137+
all the layers, and "`last`" (by default) means updating the last layer,
138+
as well as the first layer when `input_chns` is not 3.
139+
"""
76140
def __init__(self, params):
77141
super(MobileNetV2, self).__init__()
78-
self.params = params
79-
cls_num = params['class_num']
80-
in_chns = params.get('input_chns', 3)
81-
self.pretrain = params.get('pretrain', True)
82-
self.update_layers = params.get('update_layers', 0)
83142
self.net = models.mobilenet_v2(pretrained = self.pretrain)
84143

85144
# replace the last layer
86145
num_ftrs = self.net.last_channel
87-
self.net.classifier[-1] = nn.Linear(num_ftrs, cls_num)
88-
89-
def forward(self, x):
90-
return self.net(x)
146+
self.net.classifier[-1] = nn.Linear(num_ftrs, params['class_num'])
147+
148+
# replace the first layer when in_chns is not 3
149+
if(self.in_chns != 3):
150+
self.net.features[0][0] = nn.Conv2d(self.in_chns, 32, kernel_size=(3, 3),
151+
stride=(2, 2), padding=(1, 1), bias=False)
91152

92153
def get_parameters_to_update(self):
93-
if(self.pretrain == False or self.update_layers == 0):
154+
if(self.update_mode == "all"):
94155
return self.net.parameters()
95-
elif(self.update_layers == -1):
96-
return self.net.classifier[-1].parameters()
156+
elif(self.update_mode == "last"):
157+
params = self.net.classifier[-1].parameters()
158+
if(self.in_chns !=3):
159+
params = itertools.chain()
160+
for pram in [self.net.classifier[-1].parameters(), self.net.net.features[0][0].parameters()]:
161+
params = itertools.chain(params, pram)
162+
return params
97163
else:
98-
raise(ValueError("update_layers can only be 0 (all layers) " +
99-
"or -1 (the last layer)"))
164+
raise(ValueError("update_mode can only be 'all' or 'last'."))
165+
166+
if __name__ == "__main__":
167+
params = {"class_num": 2, "pretrain": False, "input_chns": 3}
168+
net = ResNet18(params)
169+
print(net)
