Skip to content

Commit fff3469

Browse files
author
Fangchang Ma
committed
migrated to pytorch 0.4.0; fixed bug with pretraining option
1 parent 237b0c4 commit fff3469

File tree

5 files changed

+83
-81
lines changed

5 files changed

+83
-81
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ Thanks to [Tim](https://github.com/timethy) for his contribution.
2626
0. [Citation](#citation)
2727

2828
## Requirements
29+
This code was tested with Python 3 and PyTorch 0.4.0.
2930
- Install [PyTorch](http://pytorch.org/) on a machine with CUDA GPU.
3031
- Install the [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) and other dependencies (files in our pre-processed datasets are in HDF5 formats).
3132
```bash
@@ -58,7 +59,7 @@ Training results will be saved under the `results` folder.
5859
## Testing
5960
To test the performance of a trained model, simply run main.py with the `-e` option, along with other model options. For instance,
6061
```bash
61-
python3 main.py -e
62+
python3 main.py -e -a resnet50 -d deconv3 -m rgbd -s 100
6263
```
6364

6465
## Trained Models

main.py

Lines changed: 44 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import argparse
22
import os
3-
import shutil
43
import time
54
import sys
65
import csv
@@ -86,8 +85,9 @@
8685
help='path to latest checkpoint (default: none)')
8786
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
8887
help='evaluate model on validation set')
89-
parser.add_argument('--pretrained', dest='pretrained', choices=['True', 'False'],
90-
default=True, help='use ImageNet pre-trained weights (default: True)')
88+
parser.add_argument('--no-pretrain', dest='pretrained', action='store_false',
89+
help='not to use ImageNet pre-trained weights')
90+
parser.set_defaults(pretrained=True)
9191

9292
args = parser.parse_args()
9393
if args.modality == 'rgb' and args.num_samples != 0:
@@ -96,7 +96,6 @@
9696
if args.modality == 'rgb' and args.max_depth != 0.0:
9797
print("max depth is forced to be 0.0 when input modality is rgb/rgbd")
9898
args.max_depth = 0.0
99-
args.pretrained = (args.pretrained == "True")
10099
print(args)
101100

102101
fieldnames = ['mse', 'rmse', 'absrel', 'lg10', 'mae',
@@ -116,9 +115,7 @@ def main():
116115
sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth)
117116

118117
# create results folder, if not already exists
119-
output_directory = os.path.join('results',
120-
'{}.sparsifier={}.modality={}.arch={}.decoder={}.criterion={}.lr={}.bs={}'.
121-
format(args.data, sparsifier, args.modality, args.arch, args.decoder, args.criterion, args.lr, args.batch_size))
118+
output_directory = utils.get_output_directory(args)
122119
if not os.path.exists(output_directory):
123120
os.makedirs(output_directory)
124121
train_csv = os.path.join(output_directory, 'train.csv')
@@ -154,31 +151,28 @@ def main():
154151
# evaluation mode
155152
if args.evaluate:
156153
best_model_filename = os.path.join(output_directory, 'model_best.pth.tar')
157-
if os.path.isfile(best_model_filename):
158-
print("=> loading best model '{}'".format(best_model_filename))
159-
checkpoint = torch.load(best_model_filename)
160-
args.start_epoch = checkpoint['epoch']
161-
best_result = checkpoint['best_result']
162-
model = checkpoint['model']
163-
print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
164-
else:
165-
print("=> no best model found at '{}'".format(best_model_filename))
154+
assert os.path.isfile(best_model_filename), \
155+
"=> no best model found at '{}'".format(best_model_filename)
156+
print("=> loading best model '{}'".format(best_model_filename))
157+
checkpoint = torch.load(best_model_filename)
158+
args.start_epoch = checkpoint['epoch']
159+
best_result = checkpoint['best_result']
160+
model = checkpoint['model']
161+
print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
166162
validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
167163
return
168164

169165
# optionally resume from a checkpoint
170166
elif args.resume:
171-
if os.path.isfile(args.resume):
172-
print("=> loading checkpoint '{}'".format(args.resume))
173-
checkpoint = torch.load(args.resume)
174-
args.start_epoch = checkpoint['epoch']+1
175-
best_result = checkpoint['best_result']
176-
model = checkpoint['model']
177-
optimizer = checkpoint['optimizer']
178-
print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
179-
else:
180-
print("=> no checkpoint found at '{}'".format(args.resume))
181-
return
167+
assert os.path.isfile(args.resume), \
168+
"=> no checkpoint found at '{}'".format(args.resume)
169+
print("=> loading checkpoint '{}'".format(args.resume))
170+
checkpoint = torch.load(args.resume)
171+
args.start_epoch = checkpoint['epoch']+1
172+
best_result = checkpoint['best_result']
173+
model = checkpoint['model']
174+
optimizer = checkpoint['optimizer']
175+
print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
182176

183177
# create new model
184178
else:
@@ -211,7 +205,7 @@ def main():
211205
print("=> model transferred to GPU.")
212206

213207
for epoch in range(args.start_epoch, args.epochs):
214-
adjust_learning_rate(optimizer, epoch)
208+
utils.adjust_learning_rate(optimizer, epoch, args.lr)
215209

216210
# train for one epoch
217211
train(train_loader, model, criterion, optimizer, epoch)
@@ -230,13 +224,14 @@ def main():
230224
img_filename = output_directory + '/comparison_best.png'
231225
utils.save_image(img_merge, img_filename)
232226

233-
save_checkpoint({
227+
utils.save_checkpoint({
228+
'args': args,
234229
'epoch': epoch,
235230
'arch': args.arch,
236231
'model': model,
237232
'best_result': best_result,
238233
'optimizer' : optimizer,
239-
}, is_best, epoch)
234+
}, is_best, epoch, output_directory)
240235

241236

242237
def train(train_loader, model, criterion, optimizer, epoch):
@@ -249,35 +244,32 @@ def train(train_loader, model, criterion, optimizer, epoch):
249244
for i, (input, target) in enumerate(train_loader):
250245

251246
input, target = input.cuda(), target.cuda()
252-
input_var = torch.autograd.Variable(input)
253-
target_var = torch.autograd.Variable(target)
254-
torch.cuda.synchronize()
247+
# torch.cuda.synchronize()
255248
data_time = time.time() - end
256249

257-
# compute depth_pred
250+
# compute pred
258251
end = time.time()
259-
depth_pred = model(input_var)
260-
loss = criterion(depth_pred, target_var)
252+
pred = model(input)
253+
loss = criterion(pred, target)
261254
optimizer.zero_grad()
262255
loss.backward() # compute gradient and do SGD step
263256
optimizer.step()
264-
torch.cuda.synchronize()
257+
# torch.cuda.synchronize()
265258
gpu_time = time.time() - end
266259

267260
# measure accuracy and record loss
268261
result = Result()
269-
output1 = torch.index_select(depth_pred.data, 1, torch.cuda.LongTensor([0]))
270-
result.evaluate(output1, target)
262+
result.evaluate(pred.data, target.data)
271263
average_meter.update(result, gpu_time, data_time, input.size(0))
272264
end = time.time()
273265

274266
if (i + 1) % args.print_freq == 0:
275267
print('=> output: {}'.format(output_directory))
276268
print('Train Epoch: {0} [{1}/{2}]\t'
277269
't_Data={data_time:.3f}({average.data_time:.3f}) '
278-
't_GPU={gpu_time:.3f}({average.gpu_time:.3f}) '
270+
't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
279271
'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
280-
'MAE={result.mae:.2f}({average.mae:.2f}) '
272+
'MAE={result.mae:.2f}({average.mae:.2f})\n\t'
281273
'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
282274
'REL={result.absrel:.3f}({average.absrel:.3f}) '
283275
'Lg10={result.lg10:.3f}({average.lg10:.3f}) '.format(
@@ -301,21 +293,19 @@ def validate(val_loader, model, epoch, write_to_file=True):
301293
end = time.time()
302294
for i, (input, target) in enumerate(val_loader):
303295
input, target = input.cuda(), target.cuda()
304-
input_var = torch.autograd.Variable(input)
305-
target_var = torch.autograd.Variable(target)
306-
torch.cuda.synchronize()
296+
# torch.cuda.synchronize()
307297
data_time = time.time() - end
308298

309299
# compute output
310300
end = time.time()
311-
depth_pred = model(input_var)
312-
torch.cuda.synchronize()
301+
with torch.no_grad():
302+
pred = model(input)
303+
# torch.cuda.synchronize()
313304
gpu_time = time.time() - end
314305

315306
# measure accuracy and record loss
316307
result = Result()
317-
output1 = torch.index_select(depth_pred.data, 1, torch.cuda.LongTensor([0]))
318-
result.evaluate(output1, target)
308+
result.evaluate(pred.data, target.data)
319309
average_meter.update(result, gpu_time, data_time, input.size(0))
320310
end = time.time()
321311

@@ -332,24 +322,24 @@ def validate(val_loader, model, epoch, write_to_file=True):
332322

333323
if i == 0:
334324
if args.modality == 'rgbd':
335-
img_merge = utils.merge_into_row_with_gt(rgb, depth, target, depth_pred)
325+
img_merge = utils.merge_into_row_with_gt(rgb, depth, target, pred)
336326
else:
337-
img_merge = utils.merge_into_row(rgb, target, depth_pred)
327+
img_merge = utils.merge_into_row(rgb, target, pred)
338328
elif (i < 8*skip) and (i % skip == 0):
339329
if args.modality == 'rgbd':
340-
row = utils.merge_into_row_with_gt(rgb, depth, target, depth_pred)
330+
row = utils.merge_into_row_with_gt(rgb, depth, target, pred)
341331
else:
342-
row = utils.merge_into_row(rgb, target, depth_pred)
332+
row = utils.merge_into_row(rgb, target, pred)
343333
img_merge = utils.add_row(img_merge, row)
344334
elif i == 8*skip:
345335
filename = output_directory + '/comparison_' + str(epoch) + '.png'
346336
utils.save_image(img_merge, filename)
347337

348338
if (i+1) % args.print_freq == 0:
349339
print('Test: [{0}/{1}]\t'
350-
't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\t'
340+
't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
351341
'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
352-
'MAE={result.mae:.2f}({average.mae:.2f}) '
342+
'MAE={result.mae:.2f}({average.mae:.2f})\n\t'
353343
'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
354344
'REL={result.absrel:.3f}({average.absrel:.3f}) '
355345
'Lg10={result.lg10:.3f}({average.lg10:.3f}) '.format(
@@ -375,22 +365,5 @@ def validate(val_loader, model, epoch, write_to_file=True):
375365

376366
return avg, img_merge
377367

378-
def save_checkpoint(state, is_best, epoch):
379-
checkpoint_filename = os.path.join(output_directory, 'checkpoint-' + str(epoch) + '.pth.tar')
380-
torch.save(state, checkpoint_filename)
381-
if is_best:
382-
best_filename = os.path.join(output_directory, 'model_best.pth.tar')
383-
shutil.copyfile(checkpoint_filename, best_filename)
384-
if epoch > 0:
385-
prev_checkpoint_filename = os.path.join(output_directory, 'checkpoint-' + str(epoch-1) + '.pth.tar')
386-
if os.path.exists(prev_checkpoint_filename):
387-
os.remove(prev_checkpoint_filename)
388-
389-
def adjust_learning_rate(optimizer, epoch):
390-
"""Sets the learning rate to the initial LR decayed by 10 every 5 epochs"""
391-
lr = args.lr * (0.1 ** (epoch // 5))
392-
for param_group in optimizer.param_groups:
393-
param_group['lr'] = lr
394-
395368
if __name__ == '__main__':
396369
main()

metrics.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,24 +35,24 @@ def evaluate(self, output, target):
3535

3636
abs_diff = (output - target).abs()
3737

38-
self.mse = (torch.pow(abs_diff, 2)).mean()
38+
self.mse = float((torch.pow(abs_diff, 2)).mean())
3939
self.rmse = math.sqrt(self.mse)
40-
self.mae = abs_diff.mean()
41-
self.lg10 = (log10(output) - log10(target)).abs().mean()
42-
self.absrel = (abs_diff / target).mean()
40+
self.mae = float(abs_diff.mean())
41+
self.lg10 = float((log10(output) - log10(target)).abs().mean())
42+
self.absrel = float((abs_diff / target).mean())
4343

4444
maxRatio = torch.max(output / target, target / output)
45-
self.delta1 = (maxRatio < 1.25).float().mean()
46-
self.delta2 = (maxRatio < 1.25 ** 2).float().mean()
47-
self.delta3 = (maxRatio < 1.25 ** 3).float().mean()
45+
self.delta1 = float((maxRatio < 1.25).float().mean())
46+
self.delta2 = float((maxRatio < 1.25 ** 2).float().mean())
47+
self.delta3 = float((maxRatio < 1.25 ** 3).float().mean())
4848
self.data_time = 0
4949
self.gpu_time = 0
5050

5151
inv_output = 1 / output
5252
inv_target = 1 / target
5353
abs_inv_diff = (inv_output - inv_target).abs()
5454
self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean())
55-
self.imae = abs_inv_diff.mean()
55+
self.imae = float(abs_inv_diff.mean())
5656

5757

5858
class AverageMeter(object):

models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def __init__(self, layers, decoder, in_channels=3, out_channels=1, pretrained=Tr
196196

197197
# setting bias=true doesn't improve accuracy
198198
self.conv3 = nn.Conv2d(num_channels//32,out_channels,kernel_size=3,stride=1,padding=1,bias=False)
199-
self.bilinear = nn.Upsample(size=(oheight, owidth), mode='bilinear')
199+
self.bilinear = nn.Upsample(size=(oheight, owidth), mode='bilinear', align_corners=True)
200200

201201
# weight init
202202
self.conv2.apply(weights_init)

utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,37 @@
1+
import os
2+
import torch
3+
import shutil
14
import numpy as np
25
import matplotlib.pyplot as plt
36
from PIL import Image
47

58
cmap = plt.cm.viridis
69

10+
def save_checkpoint(state, is_best, epoch, output_directory):
    """Serialize ``state`` for this epoch and prune the previous epoch's file.

    Args:
        state: object handed to ``torch.save`` unchanged.
        is_best: when truthy, the freshly written checkpoint is also copied
            to ``model_best.pth.tar`` in the same directory.
        epoch: epoch index used in the file name ``checkpoint-<epoch>.pth.tar``.
        output_directory: directory that receives the files.
    """
    def checkpoint_path(index):
        # All per-epoch checkpoints share one naming scheme.
        return os.path.join(output_directory, 'checkpoint-' + str(index) + '.pth.tar')

    current = checkpoint_path(epoch)
    torch.save(state, current)

    if is_best:
        shutil.copyfile(current, os.path.join(output_directory, 'model_best.pth.tar'))

    # Only one rolling checkpoint is kept: drop the file written for the
    # preceding epoch, if there is one.
    if not epoch > 0:
        return
    stale = checkpoint_path(epoch - 1)
    if os.path.exists(stale):
        os.remove(stale)
20+
21+
def adjust_learning_rate(optimizer, epoch, lr_init, decay_factor=0.1, decay_every=5):
    """Apply a step-decay learning-rate schedule to ``optimizer`` in place.

    The rate written is ``lr_init * decay_factor ** (epoch // decay_every)``.
    With the default arguments this is the initial LR decayed by 10 every
    5 epochs, which is the schedule this function has always applied.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts;
            the ``'lr'`` entry of every group is overwritten.
        epoch: current epoch index.
        lr_init: learning rate at epoch 0.
        decay_factor: multiplier applied once per completed decay period.
        decay_every: length of a decay period, in epochs.

    Returns:
        The learning rate that was written to every parameter group.

    Raises:
        ValueError: if ``decay_every`` is not positive.
    """
    if decay_every <= 0:
        raise ValueError(
            "decay_every must be a positive number of epochs, got {!r}".format(decay_every))

    lr = lr_init * (decay_factor ** (epoch // decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
26+
27+
def get_output_directory(args):
    """Return the results sub-folder path that encodes the run configuration.

    Reads ``data``, ``sparsifier``, ``modality``, ``arch``, ``decoder``,
    ``criterion``, ``lr``, ``batch_size`` and ``pretrained`` from ``args``
    and joins them into a single folder name under ``results``.
    """
    name_template = '{}.sparsifier={}.modality={}.arch={}.decoder={}.criterion={}.lr={}.bs={}.pretrained={}'
    settings = (
        args.data,
        args.sparsifier,
        args.modality,
        args.arch,
        args.decoder,
        args.criterion,
        args.lr,
        args.batch_size,
        args.pretrained,
    )
    return os.path.join('results', name_template.format(*settings))
34+
735

836
def colored_depthmap(depth, d_min=None, d_max=None):
937
if d_min is None:

0 commit comments

Comments
 (0)