6767 help = 'number of total epochs to run (default: 15)' )
6868parser .add_argument ('--start-epoch' , default = 0 , type = int , metavar = 'N' ,
6969 help = 'manual epoch number (useful on restarts)' )
70- parser .add_argument ('-c' , '--criterion' , metavar = 'LOSS' , default = 'l1' ,
70+ parser .add_argument ('-c' , '--criterion' , metavar = 'LOSS' , default = 'l1' ,
7171 choices = loss_names ,
7272 help = 'loss function: ' +
7373 ' | ' .join (loss_names ) +
8686 help = 'path to latest checkpoint (default: none)' )
8787parser .add_argument ('-e' , '--evaluate' , dest = 'evaluate' , action = 'store_true' ,
8888 help = 'evaluate model on validation set' )
89- parser .add_argument ('--pretrained' , dest = 'pretrained' , action = 'store_true' ,
89+ parser .add_argument ('--pretrained' , dest = 'pretrained' , choices = [ 'True' , 'False' ] ,
9090 default = True , help = 'use ImageNet pre-trained weights (default: True)' )
9191
92- fieldnames = ['mse' , 'rmse' , 'absrel' , 'lg10' , 'mae' ,
93- 'delta1' , 'delta2' , 'delta3' ,
92+ args = parser .parse_args ()
93+ if args .modality == 'rgb' and args .num_samples != 0 :
94+ print ("number of samples is forced to be 0 when input modality is rgb" )
95+ args .num_samples = 0
96+ if args .modality == 'rgb' and args .max_depth != 0.0 :
97+ print ("max depth is forced to be 0.0 when input modality is rgb/rgbd" )
98+ args .max_depth = 0.0
99+ args .pretrained = (args .pretrained == "True" )
100+ print (args )
101+
102+ fieldnames = ['mse' , 'rmse' , 'absrel' , 'lg10' , 'mae' ,
103+ 'delta1' , 'delta2' , 'delta3' ,
94104 'data_time' , 'gpu_time' ]
95105best_result = Result ()
96106best_result .set_to_worst ()
97107
98108def main ():
99109 global args , best_result , output_directory , train_csv , test_csv
100- args = parser .parse_args ()
101- if args .modality == 'rgb' and args .num_samples != 0 :
102- print ("number of samples is forced to be 0 when input modality is rgb" )
103- args .num_samples = 0
104- if args .modality == 'rgb' and args .max_depth != 0.0 :
105- print ("max depth is forced to be 0.0 when input modality is rgb/rgbd" )
106- args .max_depth = 0.0
107110
108111 sparsifier = None
109112 max_depth = args .max_depth if args .max_depth >= 0.0 else np .inf
@@ -121,7 +124,7 @@ def main():
121124 train_csv = os .path .join (output_directory , 'train.csv' )
122125 test_csv = os .path .join (output_directory , 'test.csv' )
123126 best_txt = os .path .join (output_directory , 'best.txt' )
124-
127+
125128 # define loss function (criterion) and optimizer
126129 if args .criterion == 'l2' :
127130 criterion = criteria .MaskedMSELoss ().cuda ()
@@ -195,10 +198,10 @@ def main():
195198 weight_decay = args .weight_decay )
196199
197200 # create new csv files with only header
198- with open (train_csv , 'w' ) as csvfile :
201+ with open (train_csv , 'w' ) as csvfile :
199202 writer = csv .DictWriter (csvfile , fieldnames = fieldnames )
200203 writer .writeheader ()
201- with open (test_csv , 'w' ) as csvfile :
204+ with open (test_csv , 'w' ) as csvfile :
202205 writer = csv .DictWriter (csvfile , fieldnames = fieldnames )
203206 writer .writeheader ()
204207
@@ -211,7 +214,7 @@ def main():
211214 adjust_learning_rate (optimizer , epoch )
212215
213216 # train for one epoch
214- # train(train_loader, model, criterion, optimizer, epoch)
217+ train (train_loader , model , criterion , optimizer , epoch )
215218
216219 # evaluate on validation set
217220 result , img_merge = validate (val_loader , model , epoch )
@@ -226,7 +229,7 @@ def main():
226229 if img_merge is not None :
227230 img_filename = output_directory + '/comparison_best.png'
228231 utils .save_image (img_merge , img_filename )
229-
232+
230233 save_checkpoint ({
231234 'epoch' : epoch ,
232235 'arch' : args .arch ,
@@ -278,14 +281,14 @@ def train(train_loader, model, criterion, optimizer, epoch):
278281 'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
279282 'REL={result.absrel:.3f}({average.absrel:.3f}) '
280283 'Lg10={result.lg10:.3f}({average.lg10:.3f}) ' .format (
281- epoch , i + 1 , len (train_loader ), data_time = data_time ,
284+ epoch , i + 1 , len (train_loader ), data_time = data_time ,
282285 gpu_time = gpu_time , result = result , average = average_meter .average ()))
283286
284287 avg = average_meter .average ()
285- with open (train_csv , 'a' ) as csvfile :
288+ with open (train_csv , 'a' ) as csvfile :
286289 writer = csv .DictWriter (csvfile , fieldnames = fieldnames )
287290 writer .writerow ({'mse' : avg .mse , 'rmse' : avg .rmse , 'absrel' : avg .absrel , 'lg10' : avg .lg10 ,
288- 'mae' : avg .mae , 'delta1' : avg .delta1 , 'delta2' : avg .delta2 , 'delta3' : avg .delta3 ,
291+ 'mae' : avg .mae , 'delta1' : avg .delta1 , 'delta2' : avg .delta2 , 'delta3' : avg .delta3 ,
289292 'gpu_time' : avg .gpu_time , 'data_time' : avg .data_time })
290293
291294
@@ -364,10 +367,10 @@ def validate(val_loader, model, epoch, write_to_file=True):
364367 average = avg , time = avg .gpu_time ))
365368
366369 if write_to_file :
367- with open (test_csv , 'a' ) as csvfile :
370+ with open (test_csv , 'a' ) as csvfile :
368371 writer = csv .DictWriter (csvfile , fieldnames = fieldnames )
369372 writer .writerow ({'mse' : avg .mse , 'rmse' : avg .rmse , 'absrel' : avg .absrel , 'lg10' : avg .lg10 ,
370- 'mae' : avg .mae , 'delta1' : avg .delta1 , 'delta2' : avg .delta2 , 'delta3' : avg .delta3 ,
373+ 'mae' : avg .mae , 'delta1' : avg .delta1 , 'delta2' : avg .delta2 , 'delta3' : avg .delta3 ,
371374 'data_time' : avg .data_time , 'gpu_time' : avg .gpu_time })
372375
373376 return avg , img_merge
0 commit comments