Skip to content

Commit ddba1cd

Browse files
author
Fangchang Ma
committed
fixed bug with image saving when input modality is not rgb
1 parent 7b2236c commit ddba1cd

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

main.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ def main():
8989
global args, best_result, output_directory, train_csv, test_csv
9090
args = parser.parse_args()
9191
args.data = os.path.join('data', args.data)
92+
if args.modality == 'rgb' and args.num_samples != 0:
93+
print("number of samples is forced to be 0 when input modality is rgb")
94+
args.num_samples = 0
9295

9396
# create results folder, if not already exists
9497
output_directory = os.path.join('results',
@@ -201,8 +204,9 @@ def main():
201204
with open(best_txt, 'w') as txtfile:
202205
txtfile.write("epoch={}\nmse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n".
203206
format(epoch, result.mse, result.rmse, result.absrel, result.lg10, result.mae, result.delta1, result.gpu_time))
204-
img_filename = output_directory + '/comparison_best.png'
205-
utils.save_image(img_merge, img_filename)
207+
if img_merge is not None:
208+
img_filename = output_directory + '/comparison_best.png'
209+
utils.save_image(img_merge, img_filename)
206210

207211
save_checkpoint({
208212
'epoch': epoch,
@@ -295,14 +299,22 @@ def validate(val_loader, model, epoch, write_to_file=True):
295299

296300
# save 8 images for visualization
297301
skip = 50
298-
if i == 0:
299-
img_merge = utils.merge_into_row(input, target, depth_pred)
300-
elif (i < 8*skip) and (i % skip == 0):
301-
row = utils.merge_into_row(input, target, depth_pred)
302-
img_merge = utils.add_row(img_merge, row)
303-
elif i == 8*skip:
304-
filename = output_directory + '/comparison_' + str(epoch) + '.png'
305-
utils.save_image(img_merge, filename)
302+
if args.modality == 'd':
303+
img_merge = None
304+
else:
305+
if args.modality == 'rgb':
306+
rgb = input
307+
elif args.modality == 'rgbd':
308+
rgb = input[:,:3,:,:]
309+
310+
if i == 0:
311+
img_merge = utils.merge_into_row(rgb, target, depth_pred)
312+
elif (i < 8*skip) and (i % skip == 0):
313+
row = utils.merge_into_row(rgb, target, depth_pred)
314+
img_merge = utils.add_row(img_merge, row)
315+
elif i == 8*skip:
316+
filename = output_directory + '/comparison_' + str(epoch) + '.png'
317+
utils.save_image(img_merge, filename)
306318

307319
if (i+1) % args.print_freq == 0:
308320
print('Test: [{0}/{1}]\t'
@@ -340,7 +352,7 @@ def save_checkpoint(state, is_best, epoch):
340352
if is_best:
341353
best_filename = os.path.join(output_directory, 'model_best.pth.tar')
342354
shutil.copyfile(checkpoint_filename, best_filename)
343-
if epoch > 1:
355+
if epoch > 0:
344356
prev_checkpoint_filename = os.path.join(output_directory, 'checkpoint-' + str(epoch-1) + '.pth.tar')
345357
if os.path.exists(prev_checkpoint_filename):
346358
os.remove(prev_checkpoint_filename)

0 commit comments

Comments
 (0)