@@ -89,6 +89,9 @@ def main():
8989 global args , best_result , output_directory , train_csv , test_csv
9090 args = parser .parse_args ()
9191 args .data = os .path .join ('data' , args .data )
92+ if args .modality == 'rgb' and args .num_samples != 0 :
93+ print ("number of samples is forced to be 0 when input modality is rgb" )
94+ args .num_samples = 0
9295
9396 # create results folder, if not already exists
9497 output_directory = os .path .join ('results' ,
@@ -201,8 +204,9 @@ def main():
201204 with open (best_txt , 'w' ) as txtfile :
202205 txtfile .write ("epoch={}\n mse={:.3f}\n rmse={:.3f}\n absrel={:.3f}\n lg10={:.3f}\n mae={:.3f}\n delta1={:.3f}\n t_gpu={:.4f}\n " .
203206 format (epoch , result .mse , result .rmse , result .absrel , result .lg10 , result .mae , result .delta1 , result .gpu_time ))
204- img_filename = output_directory + '/comparison_best.png'
205- utils .save_image (img_merge , img_filename )
207+ if img_merge is not None :
208+ img_filename = output_directory + '/comparison_best.png'
209+ utils .save_image (img_merge , img_filename )
206210
207211 save_checkpoint ({
208212 'epoch' : epoch ,
@@ -295,14 +299,22 @@ def validate(val_loader, model, epoch, write_to_file=True):
295299
296300 # save 8 images for visualization
297301 skip = 50
298- if i == 0 :
299- img_merge = utils .merge_into_row (input , target , depth_pred )
300- elif (i < 8 * skip ) and (i % skip == 0 ):
301- row = utils .merge_into_row (input , target , depth_pred )
302- img_merge = utils .add_row (img_merge , row )
303- elif i == 8 * skip :
304- filename = output_directory + '/comparison_' + str (epoch ) + '.png'
305- utils .save_image (img_merge , filename )
302+ if args .modality == 'd' :
303+ img_merge = None
304+ else :
305+ if args .modality == 'rgb' :
306+ rgb = input
307+ elif args .modality == 'rgbd' :
308+ rgb = input [:,:3 ,:,:]
309+
310+ if i == 0 :
311+ img_merge = utils .merge_into_row (rgb , target , depth_pred )
312+ elif (i < 8 * skip ) and (i % skip == 0 ):
313+ row = utils .merge_into_row (rgb , target , depth_pred )
314+ img_merge = utils .add_row (img_merge , row )
315+ elif i == 8 * skip :
316+ filename = output_directory + '/comparison_' + str (epoch ) + '.png'
317+ utils .save_image (img_merge , filename )
306318
307319 if (i + 1 ) % args .print_freq == 0 :
308320 print ('Test: [{0}/{1}]\t '
@@ -340,7 +352,7 @@ def save_checkpoint(state, is_best, epoch):
340352 if is_best :
341353 best_filename = os .path .join (output_directory , 'model_best.pth.tar' )
342354 shutil .copyfile (checkpoint_filename , best_filename )
343- if epoch > 1 :
355+ if epoch > 0 :
344356 prev_checkpoint_filename = os .path .join (output_directory , 'checkpoint-' + str (epoch - 1 ) + '.pth.tar' )
345357 if os .path .exists (prev_checkpoint_filename ):
346358 os .remove (prev_checkpoint_filename )
0 commit comments