11import argparse
22import os
3- import shutil
43import time
54import sys
65import csv
8685 help = 'path to latest checkpoint (default: none)' )
8786parser .add_argument ('-e' , '--evaluate' , dest = 'evaluate' , action = 'store_true' ,
8887 help = 'evaluate model on validation set' )
89- parser .add_argument ('--pretrained' , dest = 'pretrained' , choices = ['True' , 'False' ],
90- default = True , help = 'use ImageNet pre-trained weights (default: True)' )
88+ parser .add_argument ('--no-pretrain' , dest = 'pretrained' , action = 'store_false' ,
89+ help = 'not to use ImageNet pre-trained weights' )
90+ parser .set_defaults (pretrained = True )
9191
9292args = parser .parse_args ()
9393if args .modality == 'rgb' and args .num_samples != 0 :
9696if args .modality == 'rgb' and args .max_depth != 0.0 :
9797 print ("max depth is forced to be 0.0 when input modality is rgb/rgbd" )
9898 args .max_depth = 0.0
99- args .pretrained = (args .pretrained == "True" )
10099print (args )
101100
102101fieldnames = ['mse' , 'rmse' , 'absrel' , 'lg10' , 'mae' ,
@@ -116,9 +115,7 @@ def main():
116115 sparsifier = SimulatedStereo (num_samples = args .num_samples , max_depth = max_depth )
117116
118117 # create results folder, if not already exists
119- output_directory = os .path .join ('results' ,
120- '{}.sparsifier={}.modality={}.arch={}.decoder={}.criterion={}.lr={}.bs={}' .
121- format (args .data , sparsifier , args .modality , args .arch , args .decoder , args .criterion , args .lr , args .batch_size ))
118+ output_directory = utils .get_output_directory (args )
122119 if not os .path .exists (output_directory ):
123120 os .makedirs (output_directory )
124121 train_csv = os .path .join (output_directory , 'train.csv' )
@@ -154,31 +151,28 @@ def main():
154151 # evaluation mode
155152 if args .evaluate :
156153 best_model_filename = os .path .join (output_directory , 'model_best.pth.tar' )
157- if os .path .isfile (best_model_filename ):
158- print ("=> loading best model '{}'" .format (best_model_filename ))
159- checkpoint = torch .load (best_model_filename )
160- args .start_epoch = checkpoint ['epoch' ]
161- best_result = checkpoint ['best_result' ]
162- model = checkpoint ['model' ]
163- print ("=> loaded best model (epoch {})" .format (checkpoint ['epoch' ]))
164- else :
165- print ("=> no best model found at '{}'" .format (best_model_filename ))
154+ assert os .path .isfile (best_model_filename ), \
155+ "=> no best model found at '{}'" .format (best_model_filename )
156+ print ("=> loading best model '{}'" .format (best_model_filename ))
157+ checkpoint = torch .load (best_model_filename )
158+ args .start_epoch = checkpoint ['epoch' ]
159+ best_result = checkpoint ['best_result' ]
160+ model = checkpoint ['model' ]
161+ print ("=> loaded best model (epoch {})" .format (checkpoint ['epoch' ]))
166162 validate (val_loader , model , checkpoint ['epoch' ], write_to_file = False )
167163 return
168164
169165 # optionally resume from a checkpoint
170166 elif args .resume :
171- if os .path .isfile (args .resume ):
172- print ("=> loading checkpoint '{}'" .format (args .resume ))
173- checkpoint = torch .load (args .resume )
174- args .start_epoch = checkpoint ['epoch' ]+ 1
175- best_result = checkpoint ['best_result' ]
176- model = checkpoint ['model' ]
177- optimizer = checkpoint ['optimizer' ]
178- print ("=> loaded checkpoint (epoch {})" .format (checkpoint ['epoch' ]))
179- else :
180- print ("=> no checkpoint found at '{}'" .format (args .resume ))
181- return
167+ assert os .path .isfile (args .resume ), \
168+ "=> no checkpoint found at '{}'" .format (args .resume )
169+ print ("=> loading checkpoint '{}'" .format (args .resume ))
170+ checkpoint = torch .load (args .resume )
171+ args .start_epoch = checkpoint ['epoch' ]+ 1
172+ best_result = checkpoint ['best_result' ]
173+ model = checkpoint ['model' ]
174+ optimizer = checkpoint ['optimizer' ]
175+ print ("=> loaded checkpoint (epoch {})" .format (checkpoint ['epoch' ]))
182176
183177 # create new model
184178 else :
@@ -211,7 +205,7 @@ def main():
211205 print ("=> model transferred to GPU." )
212206
213207 for epoch in range (args .start_epoch , args .epochs ):
214- adjust_learning_rate (optimizer , epoch )
208+ utils . adjust_learning_rate (optimizer , epoch , args . lr )
215209
216210 # train for one epoch
217211 train (train_loader , model , criterion , optimizer , epoch )
@@ -230,13 +224,14 @@ def main():
230224 img_filename = output_directory + '/comparison_best.png'
231225 utils .save_image (img_merge , img_filename )
232226
233- save_checkpoint ({
227+ utils .save_checkpoint ({
228+ 'args' : args ,
234229 'epoch' : epoch ,
235230 'arch' : args .arch ,
236231 'model' : model ,
237232 'best_result' : best_result ,
238233 'optimizer' : optimizer ,
239- }, is_best , epoch )
234+ }, is_best , epoch , output_directory )
240235
241236
242237def train (train_loader , model , criterion , optimizer , epoch ):
@@ -249,35 +244,32 @@ def train(train_loader, model, criterion, optimizer, epoch):
249244 for i , (input , target ) in enumerate (train_loader ):
250245
251246 input , target = input .cuda (), target .cuda ()
252- input_var = torch .autograd .Variable (input )
253- target_var = torch .autograd .Variable (target )
254- torch .cuda .synchronize ()
247+ # torch.cuda.synchronize()
255248 data_time = time .time () - end
256249
257- # compute depth_pred
250+ # compute pred
258251 end = time .time ()
259- depth_pred = model (input_var )
260- loss = criterion (depth_pred , target_var )
252+ pred = model (input )
253+ loss = criterion (pred , target )
261254 optimizer .zero_grad ()
262255 loss .backward () # compute gradient and do SGD step
263256 optimizer .step ()
264- torch .cuda .synchronize ()
257+ # torch.cuda.synchronize()
265258 gpu_time = time .time () - end
266259
267260 # measure accuracy and record loss
268261 result = Result ()
269- output1 = torch .index_select (depth_pred .data , 1 , torch .cuda .LongTensor ([0 ]))
270- result .evaluate (output1 , target )
262+ result .evaluate (pred .data , target .data )
271263 average_meter .update (result , gpu_time , data_time , input .size (0 ))
272264 end = time .time ()
273265
274266 if (i + 1 ) % args .print_freq == 0 :
275267 print ('=> output: {}' .format (output_directory ))
276268 print ('Train Epoch: {0} [{1}/{2}]\t '
277269 't_Data={data_time:.3f}({average.data_time:.3f}) '
278- 't_GPU={gpu_time:.3f}({average.gpu_time:.3f}) '
270+ 't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n \t '
279271 'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
280- 'MAE={result.mae:.2f}({average.mae:.2f}) '
272+ 'MAE={result.mae:.2f}({average.mae:.2f})\n \t '
281273 'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
282274 'REL={result.absrel:.3f}({average.absrel:.3f}) '
283275 'Lg10={result.lg10:.3f}({average.lg10:.3f}) ' .format (
@@ -301,21 +293,19 @@ def validate(val_loader, model, epoch, write_to_file=True):
301293 end = time .time ()
302294 for i , (input , target ) in enumerate (val_loader ):
303295 input , target = input .cuda (), target .cuda ()
304- input_var = torch .autograd .Variable (input )
305- target_var = torch .autograd .Variable (target )
306- torch .cuda .synchronize ()
296+ # torch.cuda.synchronize()
307297 data_time = time .time () - end
308298
309299 # compute output
310300 end = time .time ()
311- depth_pred = model (input_var )
312- torch .cuda .synchronize ()
301+ with torch .no_grad ():
302+ pred = model (input )
303+ # torch.cuda.synchronize()
313304 gpu_time = time .time () - end
314305
315306 # measure accuracy and record loss
316307 result = Result ()
317- output1 = torch .index_select (depth_pred .data , 1 , torch .cuda .LongTensor ([0 ]))
318- result .evaluate (output1 , target )
308+ result .evaluate (pred .data , target .data )
319309 average_meter .update (result , gpu_time , data_time , input .size (0 ))
320310 end = time .time ()
321311
@@ -332,24 +322,24 @@ def validate(val_loader, model, epoch, write_to_file=True):
332322
333323 if i == 0 :
334324 if args .modality == 'rgbd' :
335- img_merge = utils .merge_into_row_with_gt (rgb , depth , target , depth_pred )
325+ img_merge = utils .merge_into_row_with_gt (rgb , depth , target , pred )
336326 else :
337- img_merge = utils .merge_into_row (rgb , target , depth_pred )
327+ img_merge = utils .merge_into_row (rgb , target , pred )
338328 elif (i < 8 * skip ) and (i % skip == 0 ):
339329 if args .modality == 'rgbd' :
340- row = utils .merge_into_row_with_gt (rgb , depth , target , depth_pred )
330+ row = utils .merge_into_row_with_gt (rgb , depth , target , pred )
341331 else :
342- row = utils .merge_into_row (rgb , target , depth_pred )
332+ row = utils .merge_into_row (rgb , target , pred )
343333 img_merge = utils .add_row (img_merge , row )
344334 elif i == 8 * skip :
345335 filename = output_directory + '/comparison_' + str (epoch ) + '.png'
346336 utils .save_image (img_merge , filename )
347337
348338 if (i + 1 ) % args .print_freq == 0 :
349339 print ('Test: [{0}/{1}]\t '
350- 't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\t '
340+ 't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n \ t '
351341 'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
352- 'MAE={result.mae:.2f}({average.mae:.2f}) '
342+ 'MAE={result.mae:.2f}({average.mae:.2f})\n \t '
353343 'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
354344 'REL={result.absrel:.3f}({average.absrel:.3f}) '
355345 'Lg10={result.lg10:.3f}({average.lg10:.3f}) ' .format (
@@ -375,22 +365,5 @@ def validate(val_loader, model, epoch, write_to_file=True):
375365
376366 return avg , img_merge
377367
def save_checkpoint(state, is_best, epoch, output_dir=None):
    """Serialize a training checkpoint and prune the previous epoch's file.

    Args:
        state: picklable object (typically a dict with model/optimizer/
            best_result entries) handed to ``torch.save``.
        is_best: if True, additionally copy this checkpoint to
            ``model_best.pth.tar`` in the same directory.
        epoch: current epoch index, embedded in the checkpoint filename.
        output_dir: destination directory. Defaults to the module-level
            ``output_directory`` so existing 3-argument callers keep working.
    """
    if output_dir is None:
        # Backward-compatible fallback to the global the original code relied on.
        output_dir = output_directory
    checkpoint_filename = os.path.join(output_dir, 'checkpoint-' + str(epoch) + '.pth.tar')
    torch.save(state, checkpoint_filename)
    if is_best:
        best_filename = os.path.join(output_dir, 'model_best.pth.tar')
        shutil.copyfile(checkpoint_filename, best_filename)
    if epoch > 0:
        # Keep only the latest checkpoint to bound disk usage.
        prev_checkpoint_filename = os.path.join(output_dir, 'checkpoint-' + str(epoch - 1) + '.pth.tar')
        if os.path.exists(prev_checkpoint_filename):
            os.remove(prev_checkpoint_filename)
388-
def adjust_learning_rate(optimizer, epoch, lr_init=None):
    """Set every param group's LR to the initial LR decayed by 10x every 5 epochs.

    Args:
        optimizer: optimizer whose ``param_groups`` dicts are updated in place.
        epoch: zero-based epoch index driving the step decay.
        lr_init: base learning rate. Defaults to the module-level ``args.lr``
            so existing 2-argument callers keep working.
    """
    if lr_init is None:
        # Backward-compatible fallback to the parsed CLI argument.
        lr_init = args.lr
    lr = lr_init * (0.1 ** (epoch // 5))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
394-
# Script entry point: parse args happens at import time above; main() runs
# the training/evaluation loop when this file is executed directly.
if __name__ == '__main__':
    main ()
0 commit comments