 import time
 import sys
 import csv
+import numpy as np
 
 import torch
 import torch.nn as nn
 from nyu_dataloader import NYUDataset
 from models import Decoder, ResNet
 from metrics import AverageMeter, Result
+from dense_to_sparse import UniformSampling, SimulatedStereo
 import criteria
 import utils
 
 model_names = ['resnet18', 'resnet50']
 loss_names = ['l1', 'l2']
-data_names = ['NYUDataset']
+data_names = ['nyudepthv2']
+sparsifier_names = [x.name for x in [UniformSampling, SimulatedStereo]]
 decoder_names = Decoder.names
 modality_names = NYUDataset.modality_names
 
                          ' (default: rgb)')
 parser.add_argument('-s', '--num-samples', default=0, type=int, metavar='N',
                     help='number of sparse depth samples (default: 0)')
+parser.add_argument('--max-depth', default=-1.0, type=float, metavar='D',
+                    help='cut-off depth of the sparsifier; a negative value means infinity (default: inf [m])')
+parser.add_argument('--sparsifier', metavar='SPARSIFIER', default=UniformSampling.name,
+                    choices=sparsifier_names,
+                    help='sparsifier: ' +
+                         ' | '.join(sparsifier_names) +
+                         ' (default: ' + UniformSampling.name + ')')
 parser.add_argument('--decoder', '-d', metavar='DECODER', default='deconv2',
                     choices=decoder_names,
                     help='decoder: ' +
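
Note: taken together, the new flags choose both the sampling pattern and a depth cut-off. A run requesting 200 uniformly sampled depth points capped at 4 m might look like "python main.py --num-samples 200 --sparsifier uar --max-depth 4.0", where the script name and the value 'uar' for UniformSampling.name are assumptions (the actual string is defined in dense_to_sparse).
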
 def main():
     global args, best_result, output_directory, train_csv, test_csv
     args = parser.parse_args()
-    args.data = os.path.join('data', args.data)
     if args.modality == 'rgb' and args.num_samples != 0:
         print("number of samples is forced to be 0 when input modality is rgb")
         args.num_samples = 0
-
+    if args.modality == 'rgb' and args.max_depth != 0.0:
+        print("max depth is forced to be 0.0 when input modality is rgb")
+        args.max_depth = 0.0
+
+    sparsifier = None
+    max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
+    if args.sparsifier == UniformSampling.name:
+        sparsifier = UniformSampling(num_samples=args.num_samples, max_depth=max_depth)
+    elif args.sparsifier == SimulatedStereo.name:
+        sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth)
+
     # create the results folder if it does not already exist
     output_directory = os.path.join('results',
-        'NYUDataset.modality={}.nsample={}.arch={}.decoder={}.criterion={}.lr={}.bs={}'.
-        format(args.modality, args.num_samples, args.arch, args.decoder, args.criterion, args.lr, args.batch_size))
+        '{}.sparsifier={}.modality={}.arch={}.decoder={}.criterion={}.lr={}.bs={}'.
+        format(args.data, sparsifier, args.modality, args.arch, args.decoder, args.criterion, args.lr, args.batch_size))
     if not os.path.exists(output_directory):
         os.makedirs(output_directory)
     train_csv = os.path.join(output_directory, 'train.csv')
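
For context, here is a minimal sketch of what a sparsifier such as UniformSampling plausibly implements, inferred from its use above. The constructor keywords and the class-level .name attribute are visible in this diff; the 'uar' value and the dense_to_sparse method body are assumptions, not the repo's actual code:

    import numpy as np

    class UniformSampling:
        name = 'uar'  # assumed identifier; matched against args.sparsifier above

        def __init__(self, num_samples, max_depth=np.inf):
            self.num_samples = num_samples
            self.max_depth = max_depth

        def __repr__(self):
            # str(sparsifier) is interpolated into output_directory above
            return '{}(ns={},md={})'.format(self.name, self.num_samples, self.max_depth)

        def dense_to_sparse(self, rgb, depth):
            # keep only pixels with a valid depth reading below the cut-off
            mask_keep = depth > 0
            if not np.isinf(self.max_depth):
                mask_keep = np.bitwise_and(mask_keep, depth <= self.max_depth)
            n_keep = np.count_nonzero(mask_keep)
            if n_keep == 0:
                return mask_keep
            # retain each valid pixel with probability num_samples / n_keep,
            # so roughly num_samples measurements survive in expectation
            prob = float(self.num_samples) / n_keep
            return np.bitwise_and(mask_keep, np.random.uniform(0, 1, depth.shape) < prob)

SimulatedStereo would presumably differ only in where it places the samples, e.g. favoring pixels a stereo matcher could actually resolve.
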
@@ -112,19 +131,19 @@ def main():
 
     # Data loading code
     print("=> creating data loaders ...")
-    traindir = os.path.join(args.data, 'train')
-    valdir = os.path.join(args.data, 'val')
+    traindir = os.path.join('data', args.data, 'train')
+    valdir = os.path.join('data', args.data, 'val')
 
-    train_dataset = NYUDataset(traindir, type='train',
-        modality=args.modality, num_samples=args.num_samples)
+    train_dataset = NYUDataset(traindir, type='train',
+        modality=args.modality, sparsifier=sparsifier)
     train_loader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.batch_size, shuffle=True,
         num_workers=args.workers, pin_memory=True, sampler=None)
 
     # set batch size to be 1 for validation
-    val_dataset = NYUDataset(valdir, type='val',
-        modality=args.modality, num_samples=args.num_samples)
-    val_loader = torch.utils.data.DataLoader(val_dataset,
+    val_dataset = NYUDataset(valdir, type='val',
+        modality=args.modality, sparsifier=sparsifier)
+    val_loader = torch.utils.data.DataLoader(val_dataset,
         batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
 
     print("=> data loaders created.")
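
Passing sparsifier= instead of num_samples= moves the sampling policy out of the dataset. A hypothetical sketch of how a dataset could consume it when building the 4-channel rgbd input (the helper names are illustrative, not NYUDataset's actual code):

    import numpy as np

    def create_sparse_depth(rgb, depth, sparsifier):
        # the sparsifier returns a boolean mask of depth pixels to keep
        mask_keep = sparsifier.dense_to_sparse(rgb, depth)
        sparse_depth = np.zeros(depth.shape, dtype=np.float32)
        sparse_depth[mask_keep] = depth[mask_keep]
        return sparse_depth

    def create_rgbd(rgb, depth, sparsifier):
        # append the sparse depth as a fourth channel: H x W x 4
        sparse_depth = create_sparse_depth(rgb, depth, sparsifier)
        return np.append(rgb, np.expand_dims(sparse_depth, axis=2), axis=2)
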
@@ -192,7 +211,7 @@ def main():
         adjust_learning_rate(optimizer, epoch)
 
         # train for one epoch
         train(train_loader, model, criterion, optimizer, epoch)
 
         # evaluate on validation set
         result, img_merge = validate(val_loader, model, epoch)
@@ -306,11 +325,18 @@ def validate(val_loader, model, epoch, write_to_file=True):
             rgb = input
         elif args.modality == 'rgbd':
             rgb = input[:,:3,:,:]
+            depth = input[:,3:,:,:]
 
         if i == 0:
-            img_merge = utils.merge_into_row(rgb, target, depth_pred)
+            if args.modality == 'rgbd':
+                img_merge = utils.merge_into_row_with_gt(rgb, depth, target, depth_pred)
+            else:
+                img_merge = utils.merge_into_row(rgb, target, depth_pred)
         elif (i < 8 * skip) and (i % skip == 0):
-            row = utils.merge_into_row(rgb, target, depth_pred)
+            if args.modality == 'rgbd':
+                row = utils.merge_into_row_with_gt(rgb, depth, target, depth_pred)
+            else:
+                row = utils.merge_into_row(rgb, target, depth_pred)
             img_merge = utils.add_row(img_merge, row)
         elif i == 8 * skip:
             filename = output_directory + '/comparison_' + str(epoch) + '.png'
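
In the rgbd branch the comparison image gains a column for the sparse input depth, hence the new merge_into_row_with_gt helper. Roughly, such a helper could look like the sketch below; the colorize placeholder is an assumption (the project's utils presumably applies a proper colormap):

    import numpy as np

    def colorize(depth):
        # placeholder: scale depth to [0, 255] grayscale, replicated to 3 channels
        d = depth.cpu().numpy().squeeze()
        d = 255.0 * d / max(float(d.max()), 1e-6)
        return np.dstack([d, d, d]).astype(np.uint8)

    def merge_into_row_with_gt(rgb, depth_input, target, depth_pred):
        # 1 x 3 x H x W RGB tensor -> H x W x 3 uint8 image (validation batch size is 1)
        rgb_img = (255.0 * np.transpose(rgb.cpu().numpy().squeeze(0), (1, 2, 0))).astype(np.uint8)
        # one row: RGB | sparse input depth | ground truth | prediction
        return np.hstack([rgb_img, colorize(depth_input), colorize(target), colorize(depth_pred)])
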