Skip to content

Commit cbf1f46

Browse files
author
Fangchang Ma
committed
added upconv and upproj
1 parent ddba1cd commit cbf1f46

File tree

2 files changed

+98
-20
lines changed

2 files changed

+98
-20
lines changed

main.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@
5353
' (default: deconv2)')
5454
parser.add_argument('-j', '--workers', default=10, type=int, metavar='N',
5555
help='number of data loading workers (default: 10)')
56-
parser.add_argument('--epochs', default=30, type=int, metavar='N',
57-
help='number of total epochs to run (default: 30)')
56+
parser.add_argument('--epochs', default=15, type=int, metavar='N',
57+
help='number of total epochs to run (default: 15)')
5858
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
5959
help='manual epoch number (useful on restarts)')
6060
parser.add_argument('-c', '--criterion', metavar='LOSS', default='l1',
@@ -106,10 +106,9 @@ def main():
106106
# define loss function (criterion) and optimizer
107107
if args.criterion == 'l2':
108108
criterion = criteria.MaskedMSELoss().cuda()
109-
out_channels = 1
110109
elif args.criterion == 'l1':
111110
criterion = criteria.MaskedL1Loss().cuda()
112-
out_channels = 1
111+
out_channels = 1
113112

114113
# Data loading code
115114
print("=> creating data loaders ...")
@@ -157,6 +156,7 @@ def main():
157156
print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
158157
else:
159158
print("=> no checkpoint found at '{}'".format(args.resume))
159+
return
160160

161161
# create new model
162162
else:
@@ -358,8 +358,8 @@ def save_checkpoint(state, is_best, epoch):
358358
os.remove(prev_checkpoint_filename)
359359

360360
def adjust_learning_rate(optimizer, epoch):
361-
"""Sets the learning rate to the initial LR decayed by 10 every 10 epochs"""
362-
lr = args.lr * (0.1 ** (epoch // 10))
361+
"""Sets the learning rate to the initial LR decayed by 10 every 5 epochs"""
362+
lr = args.lr * (0.1 ** (epoch // 5))
363363
for param_group in optimizer.param_groups:
364364
param_group['lr'] = lr
365365

models.py

Lines changed: 92 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
11
import os
22
import torch
33
import torch.nn as nn
4+
import torch.nn.functional as F
45
import torchvision.models
56
import collections
67
import math
78

89
oheight, owidth = 228, 304
910

11+
class Unpool(nn.Module):
12+
# Unpool: 2*2 unpooling with zero padding
13+
def __init__(self, num_channels, stride=2):
14+
super(Unpool, self).__init__()
15+
16+
self.num_channels = num_channels
17+
self.stride = stride
18+
19+
# create kernel [1, 0; 0, 0]
20+
self.weights = torch.autograd.Variable(torch.zeros(num_channels, 1, stride, stride).cuda()) # currently not compatible with running on CPU
21+
self.weights[:,:,0,0] = 1
22+
23+
def forward(self, x):
24+
return F.conv_transpose2d(x, self.weights, stride=self.stride, groups=self.num_channels)
25+
1026
def weights_init(m):
1127
# Initialize filters with Gaussian random weights
1228
if isinstance(m, nn.Conv2d):
@@ -26,7 +42,7 @@ def weights_init(m):
2642
class Decoder(nn.Module):
2743
# Decoder is the base class for all decoders
2844

29-
names = ['deconv{}'.format(i) for i in range(2,10)]
45+
names = ['deconv2', 'deconv3', 'upconv', 'upproj']
3046

3147
def __init__(self):
3248
super(Decoder, self).__init__()
@@ -67,15 +83,77 @@ def convt(in_channels):
6783
self.layer3 = convt(in_channels // (2 ** 2))
6884
self.layer4 = convt(in_channels // (2 ** 3))
6985

70-
71-
def choose_decoder(decoder):
72-
assert decoder[:6] == 'deconv'
73-
assert len(decoder)==7
74-
75-
num_channels = 512
76-
iheight, iwidth = 10, 8
77-
kernel_size = int(decoder[6])
78-
return DeConv(num_channels, kernel_size)
86+
class UpConv(Decoder):
87+
# UpConv decoder consists of 4 upconv modules with decreasing number of channels and increasing feature map size
88+
def upconv_module(self, in_channels):
89+
# UpConv module: unpool -> 5*5 conv -> batchnorm -> ReLU
90+
upconv = nn.Sequential(collections.OrderedDict([
91+
('unpool', Unpool(in_channels)),
92+
('conv', nn.Conv2d(in_channels,in_channels//2,kernel_size=5,stride=1,padding=2,bias=False)),
93+
('batchnorm', nn.BatchNorm2d(in_channels//2)),
94+
('relu', nn.ReLU()),
95+
]))
96+
return upconv
97+
98+
def __init__(self, in_channels):
99+
super(UpConv, self).__init__()
100+
self.layer1 = self.upconv_module(in_channels)
101+
self.layer2 = self.upconv_module(in_channels//2)
102+
self.layer3 = self.upconv_module(in_channels//4)
103+
self.layer4 = self.upconv_module(in_channels//8)
104+
105+
class UpProj(Decoder):
106+
# UpProj decoder consists of 4 upproj modules with decreasing number of channels and increasing feature map size
107+
108+
class UpProjModule(nn.Module):
109+
# UpProj module has two branches, with a Unpool at the start and a ReLu at the end
110+
# upper branch: 5*5 conv -> batchnorm -> ReLU -> 3*3 conv -> batchnorm
111+
# bottom branch: 5*5 conv -> batchnorm
112+
113+
def __init__(self, in_channels):
114+
super(UpProj.UpProjModule, self).__init__()
115+
out_channels = in_channels//2
116+
self.unpool = Unpool(in_channels)
117+
self.upper_branch = nn.Sequential(collections.OrderedDict([
118+
('conv1', nn.Conv2d(in_channels,out_channels,kernel_size=5,stride=1,padding=2,bias=False)),
119+
('batchnorm1', nn.BatchNorm2d(out_channels)),
120+
('relu', nn.ReLU()),
121+
('conv2', nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=False)),
122+
('batchnorm2', nn.BatchNorm2d(out_channels)),
123+
]))
124+
self.bottom_branch = nn.Sequential(collections.OrderedDict([
125+
('conv', nn.Conv2d(in_channels,out_channels,kernel_size=5,stride=1,padding=2,bias=False)),
126+
('batchnorm', nn.BatchNorm2d(out_channels)),
127+
]))
128+
self.relu = nn.ReLU()
129+
130+
def forward(self, x):
131+
x = self.unpool(x)
132+
x1 = self.upper_branch(x)
133+
x2 = self.bottom_branch(x)
134+
x = x1 + x2
135+
x = self.relu(x)
136+
return x
137+
138+
def __init__(self, in_channels):
139+
super(UpProj, self).__init__()
140+
self.layer1 = self.UpProjModule(in_channels)
141+
self.layer2 = self.UpProjModule(in_channels//2)
142+
self.layer3 = self.UpProjModule(in_channels//4)
143+
self.layer4 = self.UpProjModule(in_channels//8)
144+
145+
def choose_decoder(decoder, in_channels):
146+
# iheight, iwidth = 10, 8
147+
if decoder[:6] == 'deconv':
148+
assert len(decoder)==7
149+
kernel_size = int(decoder[6])
150+
return DeConv(in_channels, kernel_size)
151+
elif decoder == "upproj":
152+
return UpProj(in_channels)
153+
elif decoder == "upconv":
154+
return UpConv(in_channels)
155+
else:
156+
assert False, "invalid option for decoder: {}".format(decoder)
79157

80158

81159
class ResNet(nn.Module):
@@ -112,12 +190,12 @@ def __init__(self, layers, decoder, in_channels=3, out_channels=1, pretrained=Tr
112190
elif layers >= 50:
113191
num_channels = 2048
114192

115-
self.conv2 = nn.Conv2d(num_channels,512,kernel_size=1,bias=False)
116-
self.bn2 = nn.BatchNorm2d(512)
117-
self.decoder = choose_decoder(decoder)
193+
self.conv2 = nn.Conv2d(num_channels,num_channels//2,kernel_size=1,bias=False)
194+
self.bn2 = nn.BatchNorm2d(num_channels//2)
195+
self.decoder = choose_decoder(decoder, num_channels//2)
118196

119197
# setting bias=true doesn't improve accuracy
120-
self.conv3 = nn.Conv2d(32,out_channels,kernel_size=3,stride=1,padding=1,bias=False)
198+
self.conv3 = nn.Conv2d(num_channels//32,out_channels,kernel_size=3,stride=1,padding=1,bias=False)
121199
self.bilinear = nn.Upsample(size=(oheight, owidth), mode='bilinear')
122200

123201
# weight init

0 commit comments

Comments
 (0)