11import os
22import torch
33import torch .nn as nn
4+ import torch .nn .functional as F
45import torchvision .models
56import collections
67import math
78
89oheight , owidth = 228 , 304
910
class Unpool(nn.Module):
    """Unpooling with zero padding (2*2 by default).

    Every input value is written to the top-left corner of a
    ``stride`` x ``stride`` output cell and the rest of the cell is zero,
    so the spatial size grows by a factor of ``stride``. This is done with
    a per-channel (grouped) transposed convolution whose fixed kernel is
    ``[1, 0; 0, 0]``.

    Args:
        num_channels: number of channels of the input feature map.
        stride: upsampling factor; also used as the kernel size.
    """

    def __init__(self, num_channels, stride=2):
        super(Unpool, self).__init__()

        self.num_channels = num_channels
        self.stride = stride

        # create kernel [1, 0; 0, 0]
        weights = torch.zeros(num_channels, 1, stride, stride)
        weights[:, :, 0, 0] = 1

        # Register the kernel as a buffer instead of allocating it with a
        # hard-coded .cuda(): the module can then be built on CPU-only
        # machines and the kernel follows the module through
        # .to()/.cuda()/.cpu()/.half(). It is non-persistent so that
        # state_dict keys stay the same as before (the kernel was never
        # saved), which keeps existing checkpoints loadable.
        self.register_buffer('weights', weights, persistent=False)

    def forward(self, x):
        """Return ``x`` upsampled by ``stride`` with zeros between samples."""
        return F.conv_transpose2d(
            x, self.weights, stride=self.stride, groups=self.num_channels)
25+
1026def weights_init (m ):
1127 # Initialize filters with Gaussian random weights
1228 if isinstance (m , nn .Conv2d ):
@@ -26,7 +42,7 @@ def weights_init(m):
2642class Decoder (nn .Module ):
2743 # Decoder is the base class for all decoders
2844
29- names = ['deconv{}' . format ( i ) for i in range ( 2 , 10 ) ]
45+ names = ['deconv2' , 'deconv3' , 'upconv' , 'upproj' ]
3046
3147 def __init__ (self ):
3248 super (Decoder , self ).__init__ ()
@@ -67,15 +83,77 @@ def convt(in_channels):
6783 self .layer3 = convt (in_channels // (2 ** 2 ))
6884 self .layer4 = convt (in_channels // (2 ** 3 ))
6985
70-
71- def choose_decoder (decoder ):
72- assert decoder [:6 ] == 'deconv'
73- assert len (decoder )== 7
74-
75- num_channels = 512
76- iheight , iwidth = 10 , 8
77- kernel_size = int (decoder [6 ])
78- return DeConv (num_channels , kernel_size )
class UpConv(Decoder):
    """Decoder made of four chained upconv modules.

    Each module halves the number of channels and enlarges the feature map
    through an ``Unpool`` layer.
    """

    def __init__(self, in_channels):
        super(UpConv, self).__init__()
        # layer1..layer4 receive in_channels, /2, /4 and /8 channels.
        for index, divisor in enumerate((1, 2, 4, 8), start=1):
            stage = self.upconv_module(in_channels // divisor)
            setattr(self, 'layer{}'.format(index), stage)

    def upconv_module(self, in_channels):
        """Build one module: unpool -> 5*5 conv -> batchnorm -> ReLU."""
        out_channels = in_channels // 2

        stages = collections.OrderedDict()
        stages['unpool'] = Unpool(in_channels)
        stages['conv'] = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=5, stride=1, padding=2, bias=False)
        stages['batchnorm'] = nn.BatchNorm2d(out_channels)
        stages['relu'] = nn.ReLU()
        return nn.Sequential(stages)
104+
class UpProj(Decoder):
    """Decoder made of four chained upproj modules.

    Each module halves the number of channels and enlarges the feature map
    through an ``Unpool`` layer.
    """

    class UpProjModule(nn.Module):
        """Two-branch up-projection block.

        The input is unpooled first, then fed to both branches; the branch
        outputs are summed and passed through a ReLU.

        upper branch:  5*5 conv -> batchnorm -> ReLU -> 3*3 conv -> batchnorm
        bottom branch: 5*5 conv -> batchnorm
        """

        def __init__(self, in_channels):
            super(UpProj.UpProjModule, self).__init__()
            out_channels = in_channels // 2

            self.unpool = Unpool(in_channels)

            upper_layers = collections.OrderedDict()
            upper_layers['conv1'] = nn.Conv2d(
                in_channels, out_channels,
                kernel_size=5, stride=1, padding=2, bias=False)
            upper_layers['batchnorm1'] = nn.BatchNorm2d(out_channels)
            upper_layers['relu'] = nn.ReLU()
            upper_layers['conv2'] = nn.Conv2d(
                out_channels, out_channels,
                kernel_size=3, stride=1, padding=1, bias=False)
            upper_layers['batchnorm2'] = nn.BatchNorm2d(out_channels)
            self.upper_branch = nn.Sequential(upper_layers)

            bottom_layers = collections.OrderedDict()
            bottom_layers['conv'] = nn.Conv2d(
                in_channels, out_channels,
                kernel_size=5, stride=1, padding=2, bias=False)
            bottom_layers['batchnorm'] = nn.BatchNorm2d(out_channels)
            self.bottom_branch = nn.Sequential(bottom_layers)

            self.relu = nn.ReLU()

        def forward(self, x):
            unpooled = self.unpool(x)
            upper = self.upper_branch(unpooled)
            bottom = self.bottom_branch(unpooled)
            return self.relu(upper + bottom)

    def __init__(self, in_channels):
        super(UpProj, self).__init__()
        # layer1..layer4 receive in_channels, /2, /4 and /8 channels.
        for index, divisor in enumerate((1, 2, 4, 8), start=1):
            stage = self.UpProjModule(in_channels // divisor)
            setattr(self, 'layer{}'.format(index), stage)
144+
def choose_decoder(decoder, in_channels):
    """Build the decoder module selected by name.

    Args:
        decoder: ``'deconvK'`` where ``K`` is a single digit giving the
            kernel size, ``'upproj'`` or ``'upconv'``.
        in_channels: number of input channels handed to the decoder's
            constructor.

    Returns:
        A ``DeConv``, ``UpProj`` or ``UpConv`` instance.

    Raises:
        ValueError: if ``decoder`` is not one of the supported options.
            Explicit raises are used rather than ``assert`` so the check
            still runs under ``python -O``, where asserts are stripped and
            an unknown name would otherwise fall through and return None.
    """
    if decoder == "upproj":
        return UpProj(in_channels)
    if decoder == "upconv":
        return UpConv(in_channels)

    if decoder[:6] == 'deconv':
        # Exactly one trailing digit encodes the kernel size, e.g. 'deconv3'.
        suffix = decoder[6:]
        if len(suffix) != 1 or suffix not in "0123456789":
            raise ValueError(
                "invalid kernel size in decoder option: {}".format(decoder))
        return DeConv(in_channels, int(suffix))

    raise ValueError("invalid option for decoder: {}".format(decoder))
79157
80158
81159class ResNet (nn .Module ):
@@ -112,12 +190,12 @@ def __init__(self, layers, decoder, in_channels=3, out_channels=1, pretrained=Tr
112190 elif layers >= 50 :
113191 num_channels = 2048
114192
115- self .conv2 = nn .Conv2d (num_channels ,512 ,kernel_size = 1 ,bias = False )
116- self .bn2 = nn .BatchNorm2d (512 )
117- self .decoder = choose_decoder (decoder )
193+ self .conv2 = nn .Conv2d (num_channels ,num_channels // 2 ,kernel_size = 1 ,bias = False )
194+ self .bn2 = nn .BatchNorm2d (num_channels // 2 )
195+ self .decoder = choose_decoder (decoder , num_channels // 2 )
118196
119197 # setting bias=true doesn't improve accuracy
120- self .conv3 = nn .Conv2d (32 ,out_channels ,kernel_size = 3 ,stride = 1 ,padding = 1 ,bias = False )
198+ self .conv3 = nn .Conv2d (num_channels // 32 ,out_channels ,kernel_size = 3 ,stride = 1 ,padding = 1 ,bias = False )
121199 self .bilinear = nn .Upsample (size = (oheight , owidth ), mode = 'bilinear' )
122200
123201 # weight init
0 commit comments