Skip to content

Commit 23def5d

Browse files
committed
Modify main.py and models.py to load the KITTI dataset
1 parent 1d611bd commit 23def5d

File tree

2 files changed

+33
-17
lines changed

2 files changed

+33
-17
lines changed

main.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import sys
55
import csv
66
import numpy as np
7+
import errno
78

89
import torch
910
import torch.nn as nn
@@ -13,6 +14,7 @@
1314
import torch.utils.data
1415

1516
from nyu_dataloader import NYUDataset
17+
from kitti_dataloader import KITTIDataset
1618
from models import Decoder, ResNet
1719
from metrics import AverageMeter, Result
1820
from dense_to_sparse import UniformSampling, SimulatedStereo
@@ -21,7 +23,7 @@
2123

2224
model_names = ['resnet18', 'resnet50']
2325
loss_names = ['l1', 'l2']
24-
data_names = ['nyudepthv2']
26+
data_names = ['nyudepthv2', 'kitti']
2527
sparsifier_names = [x.name for x in [UniformSampling, SimulatedStereo]]
2628
decoder_names = Decoder.names
2729
modality_names = NYUDataset.modality_names
@@ -134,15 +136,27 @@ def main():
134136
traindir = os.path.join('data', args.data, 'train')
135137
valdir = os.path.join('data', args.data, 'val')
136138

137-
train_dataset = NYUDataset(traindir, type='train',
138-
modality=args.modality, sparsifier=sparsifier)
139+
if args.data == 'nyudepthv2':
140+
train_dataset = NYUDataset(traindir, type='train',
141+
modality=args.modality, sparsifier=sparsifier)
142+
val_dataset = NYUDataset(valdir, type='val',
143+
modality=args.modality, sparsifier=sparsifier)
144+
145+
elif args.data == 'kitti':
146+
train_dataset = KITTIDataset(traindir, type='train',
147+
modality=args.modality, sparsifier=sparsifier)
148+
val_dataset = KITTIDataset(valdir, type='val',
149+
modality=args.modality, sparsifier=sparsifier)
150+
151+
else:
152+
raise RuntimeError('Dataset not found.' +
153+
'The dataset must be either of nyudepthv2 or kitti.')
154+
139155
train_loader = torch.utils.data.DataLoader(
140156
train_dataset, batch_size=args.batch_size, shuffle=True,
141157
num_workers=args.workers, pin_memory=True, sampler=None)
142158

143159
# set batch size to be 1 for validation
144-
val_dataset = NYUDataset(valdir, type='val',
145-
modality=args.modality, sparsifier=sparsifier)
146160
val_loader = torch.utils.data.DataLoader(val_dataset,
147161
batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
148162

models.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,31 @@
99
oheight, owidth = 228, 304
1010

1111
class Unpool(nn.Module):
12-
# Unpool: 2*2 unpooling with zero padding
12+
# Unpool: 2*2 unpooling with zero padding
1313
def __init__(self, num_channels, stride=2):
1414
super(Unpool, self).__init__()
1515

1616
self.num_channels = num_channels
1717
self.stride = stride
1818

1919
# create kernel [1, 0; 0, 0]
20-
self.weights = torch.autograd.Variable(torch.zeros(num_channels, 1, stride, stride).cuda()) # currently not compatible with running on CPU
20+
self.weights = torch.autograd.Variable(torch.zeros(num_channels, 1, stride, stride).cuda()) # currently not compatible with running on CPU
2121
self.weights[:,:,0,0] = 1
2222

2323
def forward(self, x):
2424
return F.conv_transpose2d(x, self.weights, stride=self.stride, groups=self.num_channels)
2525

2626
def weights_init(m):
2727
# Initialize filters with Gaussian random weights
28-
if isinstance(m, nn.Conv2d):
28+
if isinstance(m, nn.Conv2d):
2929
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
3030
m.weight.data.normal_(0, math.sqrt(2. / n))
31-
if m.bias is not None:
31+
if m.bias is not None:
3232
m.bias.data.zero_()
3333
elif isinstance(m, nn.ConvTranspose2d):
3434
n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
3535
m.weight.data.normal_(0, math.sqrt(2. / n))
36-
if m.bias is not None:
36+
if m.bias is not None:
3737
m.bias.data.zero_()
3838
elif isinstance(m, nn.BatchNorm2d):
3939
m.weight.data.fill_(1)
@@ -63,13 +63,13 @@ class DeConv(Decoder):
6363
def __init__(self, in_channels, kernel_size):
6464
assert kernel_size>=2, "kernel_size out of range: {}".format(kernel_size)
6565
super(DeConv, self).__init__()
66-
66+
6767
def convt(in_channels):
6868
stride = 2
6969
padding = (kernel_size - 1) // 2
7070
output_padding = kernel_size % 2
7171
assert -2 - 2*padding + kernel_size + output_padding == 0, "deconv parameters incorrect"
72-
72+
7373
module_name = "deconv{}".format(kernel_size)
7474
return nn.Sequential(collections.OrderedDict([
7575
(module_name, nn.ConvTranspose2d(in_channels,in_channels//2,kernel_size,
@@ -107,7 +107,7 @@ class UpProj(Decoder):
107107

108108
class UpProjModule(nn.Module):
109109
# UpProj module has two branches, with an Unpool at the start and a ReLU at the end
110-
# upper branch: 5*5 conv -> batchnorm -> ReLU -> 3*3 conv -> batchnorm
110+
# upper branch: 5*5 conv -> batchnorm -> ReLU -> 3*3 conv -> batchnorm
111111
# bottom branch: 5*5 conv -> batchnorm
112112

113113
def __init__(self, in_channels):
@@ -145,7 +145,7 @@ def __init__(self, in_channels):
145145
def choose_decoder(decoder, in_channels):
146146
# iheight, iwidth = 10, 8
147147
if decoder[:6] == 'deconv':
148-
assert len(decoder)==7
148+
assert len(decoder)==7
149149
kernel_size = int(decoder[6])
150150
return DeConv(in_channels, kernel_size)
151151
elif decoder == "upproj":
@@ -161,10 +161,10 @@ def __init__(self, layers, decoder, in_channels=3, out_channels=1, pretrained=Tr
161161

162162
if layers not in [18, 34, 50, 101, 152]:
163163
raise RuntimeError('Only 18, 34, 50, 101, and 152 layer model are defined for ResNet. Got {}'.format(layers))
164-
164+
165165
super(ResNet, self).__init__()
166166
pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained)
167-
167+
168168
if in_channels == 3:
169169
self.conv1 = pretrained_model._modules['conv1']
170170
self.bn1 = pretrained_model._modules['bn1']
@@ -173,7 +173,7 @@ def __init__(self, layers, decoder, in_channels=3, out_channels=1, pretrained=Tr
173173
self.bn1 = nn.BatchNorm2d(64)
174174
weights_init(self.conv1)
175175
weights_init(self.bn1)
176-
176+
177177
self.relu = pretrained_model._modules['relu']
178178
self.maxpool = pretrained_model._modules['maxpool']
179179
self.layer1 = pretrained_model._modules['layer1']
@@ -187,6 +187,8 @@ def __init__(self, layers, decoder, in_channels=3, out_channels=1, pretrained=Tr
187187
# define number of intermediate channels
188188
if layers <= 34:
189189
num_channels = 512
190+
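# Need to modify owidth for ResNet18/34 models. NOTE(review): unless `global owidth` is declared earlier in __init__ (not visible in this hunk), the assignment below only binds a function-local name and the module-level owidth (304) stays unchanged -- TODO confirm.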
191+
owidth = 912
190192
elif layers >= 50:
191193
num_channels = 2048
192194

0 commit comments

Comments
 (0)