@@ -15,12 +15,14 @@
 from nyu_dataloader import NYUDataset
 from models import Decoder, ResNet
 from metrics import AverageMeter, Result
+from dense_to_sparse import UniformSampling, SimulatedStereo
 import criteria
 import utils
 
 model_names = ['resnet18', 'resnet50']
 loss_names = ['l1', 'l2']
-data_names = ['NYUDataset']
+data_names = ['nyudepthv2']
+sparsifier_names = [x.name for x in [UniformSampling, SimulatedStereo]]
 decoder_names = Decoder.names
 modality_names = NYUDataset.modality_names
 
@@ -46,6 +48,13 @@
                     ' (default: rgb)')
 parser.add_argument('-s', '--num-samples', default=0, type=int, metavar='N',
                     help='number of sparse depth samples (default: 0)')
+parser.add_argument('--max-depth', default=-1.0, type=float, metavar='D',
+                    help='cut-off depth of the sparsifier; a negative value means infinity (default: inf [m])')
+parser.add_argument('--sparsifier', metavar='SPARSIFIER', default=UniformSampling.name,
+                    choices=sparsifier_names,
+                    help='sparsifier: ' +
+                         ' | '.join(sparsifier_names) +
+                         ' (default: ' + UniformSampling.name + ')')
 parser.add_argument('--decoder', '-d', metavar='DECODER', default='deconv2',
                     choices=decoder_names,
                     help='decoder: ' +
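
A quick way to sanity-check the new flags is to parse them directly. A minimal sketch, assuming the file above is `main.py` (the commit does not show the file name); the valid sparsifier names are whatever `UniformSampling.name` and `SimulatedStereo.name` evaluate to in `dense_to_sparse.py`:

```python
# hypothetical sanity check of the new CLI flags; assumes this script is main.py
import main

args = main.parser.parse_args(
    ['-s', '200', '--max-depth', '10.0',
     '--sparsifier', main.sparsifier_names[0]])
print(args.sparsifier, args.max_depth)  # prints the chosen sparsifier name and 10.0
```
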
@@ -88,15 +97,24 @@
 def main():
     global args, best_result, output_directory, train_csv, test_csv
     args = parser.parse_args()
-    args.data = os.path.join('data', args.data)
     if args.modality == 'rgb' and args.num_samples != 0:
         print("number of samples is forced to be 0 when input modality is rgb")
         args.num_samples = 0
-
+    if args.modality == 'rgb' and args.max_depth != 0.0:
+        print("max depth is forced to be 0.0 when input modality is rgb")
+        args.max_depth = 0.0
+
+    sparsifier = None
+    max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
+    if args.sparsifier == UniformSampling.name:
+        sparsifier = UniformSampling(num_samples=args.num_samples, max_depth=max_depth)
+    elif args.sparsifier == SimulatedStereo.name:
+        sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth)
+
     # create results folder, if not already exists
     output_directory = os.path.join('results',
-        'NYUDataset.modality={}.nsample={}.arch={}.decoder={}.criterion={}.lr={}.bs={}'.
-        format(args.modality, args.num_samples, args.arch, args.decoder, args.criterion, args.lr, args.batch_size))
+        '{}.sparsifier={}.modality={}.arch={}.decoder={}.criterion={}.lr={}.bs={}'.
+        format(args.data, sparsifier, args.modality, args.arch, args.decoder, args.criterion, args.lr, args.batch_size))
     if not os.path.exists(output_directory):
         os.makedirs(output_directory)
     train_csv = os.path.join(output_directory, 'train.csv')
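
The commit only shows that both sparsifiers expose a class-level `name` and take `(num_samples, max_depth)`. A minimal sketch of what such a sparsifier might look like, assuming a `dense_to_sparse(rgb, depth)` method that returns a boolean keep-mask; the method name, the `'uar'` value, and the uniform-probability sampling are assumptions, not taken from this diff:

```python
import numpy as np

class UniformSampling(object):
    # assumed interface: this commit only shows `name`, `num_samples`, `max_depth`
    name = 'uar'  # hypothetical value; the real one lives in dense_to_sparse.py

    def __init__(self, num_samples, max_depth=np.inf):
        self.num_samples = num_samples
        self.max_depth = max_depth

    def dense_to_sparse(self, rgb, depth):
        # boolean mask keeping ~num_samples pixels with valid, near-enough depth
        mask_keep = depth > 0
        if not np.isinf(self.max_depth):
            mask_keep = np.bitwise_and(mask_keep, depth <= self.max_depth)
        n_keep = np.count_nonzero(mask_keep)
        if n_keep == 0:
            return mask_keep
        prob = float(self.num_samples) / n_keep
        return np.bitwise_and(mask_keep,
                              np.random.uniform(0, 1, depth.shape) < prob)
```

This is why `max_depth = np.inf` is a sensible encoding of "no cut-off" above: the mask then degenerates to the plain validity check `depth > 0`.
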
@@ -112,19 +130,19 @@ def main():
 
     # Data loading code
     print("=> creating data loaders ...")
-    traindir = os.path.join(args.data, 'train')
-    valdir = os.path.join(args.data, 'val')
+    traindir = os.path.join('data', args.data, 'train')
+    valdir = os.path.join('data', args.data, 'val')
 
-    train_dataset = NYUDataset(traindir, type='train',
-        modality=args.modality, num_samples=args.num_samples)
+    train_dataset = NYUDataset(traindir, type='train',
+        modality=args.modality, sparsifier=sparsifier)
     train_loader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.batch_size, shuffle=True,
         num_workers=args.workers, pin_memory=True, sampler=None)
 
     # set batch size to be 1 for validation
-    val_dataset = NYUDataset(valdir, type='val',
-        modality=args.modality, num_samples=args.num_samples)
-    val_loader = torch.utils.data.DataLoader(val_dataset,
+    val_dataset = NYUDataset(valdir, type='val',
+        modality=args.modality, sparsifier=sparsifier)
+    val_loader = torch.utils.data.DataLoader(val_dataset,
         batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
 
     print("=> data loaders created.")
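
The dataset now receives the sparsifier object instead of a raw `num_samples`. A sketch of how `NYUDataset` could use it to build the sparse depth channel and the 4-channel rgbd input (assumed; `nyu_dataloader.py` is not part of this diff and the helper names are hypothetical):

```python
import numpy as np

def create_sparse_depth(sparsifier, rgb, depth):
    # zero out every pixel the sparsifier's boolean mask does not keep
    if sparsifier is None:
        return depth
    mask_keep = sparsifier.dense_to_sparse(rgb, depth)
    sparse_depth = np.zeros(depth.shape)
    sparse_depth[mask_keep] = depth[mask_keep]
    return sparse_depth

def create_rgbd(sparsifier, rgb, depth):
    # stack the sparse depth as a 4th channel: H x W x 4
    sparse_depth = create_sparse_depth(sparsifier, rgb, depth)
    return np.append(rgb, np.expand_dims(sparse_depth, axis=2), axis=2)
```
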
@@ -306,11 +324,18 @@ def validate(val_loader, model, epoch, write_to_file=True):
             rgb = input
         elif args.modality == 'rgbd':
             rgb = input[:,:3,:,:]
+            depth = input[:,3:,:,:]
 
         if i == 0:
-            img_merge = utils.merge_into_row(rgb, target, depth_pred)
+            if args.modality == 'rgbd':
+                img_merge = utils.merge_into_row_with_gt(rgb, depth, target, depth_pred)
+            else:
+                img_merge = utils.merge_into_row(rgb, target, depth_pred)
         elif (i < 8*skip) and (i % skip == 0):
-            row = utils.merge_into_row(rgb, target, depth_pred)
+            if args.modality == 'rgbd':
+                row = utils.merge_into_row_with_gt(rgb, depth, target, depth_pred)
+            else:
+                row = utils.merge_into_row(rgb, target, depth_pred)
             img_merge = utils.add_row(img_merge, row)
         elif i == 8*skip:
             filename = output_directory + '/comparison_' + str(epoch) + '.png'
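
`utils.merge_into_row_with_gt` is new in this commit but its body is not shown. A plausible sketch, assuming it colorizes the sparse input depth, ground truth, and prediction on one shared color scale and concatenates them with the rgb frame into a single row; the function body and the `colored_depthmap` helper are assumptions:

```python
import numpy as np
import matplotlib.pyplot as plt

cmap = plt.cm.viridis

def colored_depthmap(depth, d_min, d_max):
    # map an H x W depth array to an H x W x 3 uint8 image via a colormap
    depth_rel = (depth - d_min) / (d_max - d_min)  # assumes d_max > d_min
    return (255 * cmap(depth_rel)[:, :, :3]).astype('uint8')

def merge_into_row_with_gt(input, depth_input, depth_target, depth_pred):
    # tensors are 1 x C x H x W; convert to H x W (x C) numpy arrays
    rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1, 2, 0))
    d_in = np.squeeze(depth_input.cpu().numpy())
    d_gt = np.squeeze(depth_target.cpu().numpy())
    d_pr = np.squeeze(depth_pred.detach().cpu().numpy())
    # one shared scale so sparse input, gt, and prediction are comparable
    d_min = min(d_in.min(), d_gt.min(), d_pr.min())
    d_max = max(d_in.max(), d_gt.max(), d_pr.max())
    return np.hstack([rgb.astype('uint8'),
                      colored_depthmap(d_in, d_min, d_max),
                      colored_depthmap(d_gt, d_min, d_max),
                      colored_depthmap(d_pr, d_min, d_max)])
```
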