99oheight , owidth = 228 , 304
1010
1111class Unpool (nn .Module ):
12- # Unpool: 2*2 unpooling with zero padding
12+ # Unpool: 2*2 unpooling with zero padding
1313 def __init__ (self , num_channels , stride = 2 ):
1414 super (Unpool , self ).__init__ ()
1515
1616 self .num_channels = num_channels
1717 self .stride = stride
1818
1919 # create kernel [1, 0; 0, 0]
20- self .weights = torch .autograd .Variable (torch .zeros (num_channels , 1 , stride , stride ).cuda ()) # currently not compatible with running on CPU
20+ self .weights = torch .autograd .Variable (torch .zeros (num_channels , 1 , stride , stride ).cuda ()) # currently not compatible with running on CPU
2121 self .weights [:,:,0 ,0 ] = 1
2222
2323 def forward (self , x ):
2424 return F .conv_transpose2d (x , self .weights , stride = self .stride , groups = self .num_channels )
2525
2626def weights_init (m ):
2727 # Initialize filters with Gaussian random weights
28- if isinstance (m , nn .Conv2d ):
28+ if isinstance (m , nn .Conv2d ):
2929 n = m .kernel_size [0 ] * m .kernel_size [1 ] * m .out_channels
3030 m .weight .data .normal_ (0 , math .sqrt (2. / n ))
31- if m .bias is not None :
31+ if m .bias is not None :
3232 m .bias .data .zero_ ()
3333 elif isinstance (m , nn .ConvTranspose2d ):
3434 n = m .kernel_size [0 ] * m .kernel_size [1 ] * m .in_channels
3535 m .weight .data .normal_ (0 , math .sqrt (2. / n ))
36- if m .bias is not None :
36+ if m .bias is not None :
3737 m .bias .data .zero_ ()
3838 elif isinstance (m , nn .BatchNorm2d ):
3939 m .weight .data .fill_ (1 )
@@ -63,13 +63,13 @@ class DeConv(Decoder):
6363 def __init__ (self , in_channels , kernel_size ):
6464 assert kernel_size >= 2 , "kernel_size out of range: {}" .format (kernel_size )
6565 super (DeConv , self ).__init__ ()
66-
66+
6767 def convt (in_channels ):
6868 stride = 2
6969 padding = (kernel_size - 1 ) // 2
7070 output_padding = kernel_size % 2
7171 assert - 2 - 2 * padding + kernel_size + output_padding == 0 , "deconv parameters incorrect"
72-
72+
7373 module_name = "deconv{}" .format (kernel_size )
7474 return nn .Sequential (collections .OrderedDict ([
7575 (module_name , nn .ConvTranspose2d (in_channels ,in_channels // 2 ,kernel_size ,
@@ -107,7 +107,7 @@ class UpProj(Decoder):
107107
108108 class UpProjModule (nn .Module ):
109109 # UpProj module has two branches, with a Unpool at the start and a ReLu at the end
110- # upper branch: 5*5 conv -> batchnorm -> ReLU -> 3*3 conv -> batchnorm
110+ # upper branch: 5*5 conv -> batchnorm -> ReLU -> 3*3 conv -> batchnorm
111111 # bottom branch: 5*5 conv -> batchnorm
112112
113113 def __init__ (self , in_channels ):
@@ -145,7 +145,7 @@ def __init__(self, in_channels):
145145def choose_decoder (decoder , in_channels ):
146146 # iheight, iwidth = 10, 8
147147 if decoder [:6 ] == 'deconv' :
148- assert len (decoder )== 7
148+ assert len (decoder )== 7
149149 kernel_size = int (decoder [6 ])
150150 return DeConv (in_channels , kernel_size )
151151 elif decoder == "upproj" :
@@ -161,10 +161,10 @@ def __init__(self, layers, decoder, in_channels=3, out_channels=1, pretrained=Tr
161161
162162 if layers not in [18 , 34 , 50 , 101 , 152 ]:
163163 raise RuntimeError ('Only 18, 34, 50, 101, and 152 layer model are defined for ResNet. Got {}' .format (layers ))
164-
164+
165165 super (ResNet , self ).__init__ ()
166166 pretrained_model = torchvision .models .__dict__ ['resnet{}' .format (layers )](pretrained = pretrained )
167-
167+
168168 if in_channels == 3 :
169169 self .conv1 = pretrained_model ._modules ['conv1' ]
170170 self .bn1 = pretrained_model ._modules ['bn1' ]
@@ -173,7 +173,7 @@ def __init__(self, layers, decoder, in_channels=3, out_channels=1, pretrained=Tr
173173 self .bn1 = nn .BatchNorm2d (64 )
174174 weights_init (self .conv1 )
175175 weights_init (self .bn1 )
176-
176+
177177 self .relu = pretrained_model ._modules ['relu' ]
178178 self .maxpool = pretrained_model ._modules ['maxpool' ]
179179 self .layer1 = pretrained_model ._modules ['layer1' ]
@@ -187,6 +187,8 @@ def __init__(self, layers, decoder, in_channels=3, out_channels=1, pretrained=Tr
187187 # define number of intermediate channels
188188 if layers <= 34 :
189189 num_channels = 512
190+ # Need to modify owidth for ResNet18 model.
191+ owidth = 912
190192 elif layers >= 50 :
191193 num_channels = 2048
192194
0 commit comments