Skip to content

Commit 978b3b1

Browse files
author
Fangchang Ma
committed
minor updates to main.py
1 parent f16e13f commit 978b3b1

File tree

1 file changed

+24
-21
lines changed

1 file changed

+24
-21
lines changed

main.py

Lines changed: 24 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@
6767
help='number of total epochs to run (default: 15)')
6868
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
6969
help='manual epoch number (useful on restarts)')
70-
parser.add_argument('-c', '--criterion', metavar='LOSS', default='l1',
70+
parser.add_argument('-c', '--criterion', metavar='LOSS', default='l1',
7171
choices=loss_names,
7272
help='loss function: ' +
7373
' | '.join(loss_names) +
@@ -86,24 +86,27 @@
8686
help='path to latest checkpoint (default: none)')
8787
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
8888
help='evaluate model on validation set')
89-
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
89+
parser.add_argument('--pretrained', dest='pretrained', choices=['True', 'False'],
9090
default=True, help='use ImageNet pre-trained weights (default: True)')
9191

92-
fieldnames = ['mse', 'rmse', 'absrel', 'lg10', 'mae',
93-
'delta1', 'delta2', 'delta3',
92+
args = parser.parse_args()
93+
if args.modality == 'rgb' and args.num_samples != 0:
94+
print("number of samples is forced to be 0 when input modality is rgb")
95+
args.num_samples = 0
96+
if args.modality == 'rgb' and args.max_depth != 0.0:
97+
print("max depth is forced to be 0.0 when input modality is rgb/rgbd")
98+
args.max_depth = 0.0
99+
args.pretrained = (args.pretrained == "True")
100+
print(args)
101+
102+
fieldnames = ['mse', 'rmse', 'absrel', 'lg10', 'mae',
103+
'delta1', 'delta2', 'delta3',
94104
'data_time', 'gpu_time']
95105
best_result = Result()
96106
best_result.set_to_worst()
97107

98108
def main():
99109
global args, best_result, output_directory, train_csv, test_csv
100-
args = parser.parse_args()
101-
if args.modality == 'rgb' and args.num_samples != 0:
102-
print("number of samples is forced to be 0 when input modality is rgb")
103-
args.num_samples = 0
104-
if args.modality == 'rgb' and args.max_depth != 0.0:
105-
print("max depth is forced to be 0.0 when input modality is rgb/rgbd")
106-
args.max_depth = 0.0
107110

108111
sparsifier = None
109112
max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
@@ -121,7 +124,7 @@ def main():
121124
train_csv = os.path.join(output_directory, 'train.csv')
122125
test_csv = os.path.join(output_directory, 'test.csv')
123126
best_txt = os.path.join(output_directory, 'best.txt')
124-
127+
125128
# define loss function (criterion) and optimizer
126129
if args.criterion == 'l2':
127130
criterion = criteria.MaskedMSELoss().cuda()
@@ -195,10 +198,10 @@ def main():
195198
weight_decay=args.weight_decay)
196199

197200
# create new csv files with only header
198-
with open(train_csv, 'w') as csvfile:
201+
with open(train_csv, 'w') as csvfile:
199202
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
200203
writer.writeheader()
201-
with open(test_csv, 'w') as csvfile:
204+
with open(test_csv, 'w') as csvfile:
202205
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
203206
writer.writeheader()
204207

@@ -211,7 +214,7 @@ def main():
211214
adjust_learning_rate(optimizer, epoch)
212215

213216
# train for one epoch
214-
# train(train_loader, model, criterion, optimizer, epoch)
217+
train(train_loader, model, criterion, optimizer, epoch)
215218

216219
# evaluate on validation set
217220
result, img_merge = validate(val_loader, model, epoch)
@@ -226,7 +229,7 @@ def main():
226229
if img_merge is not None:
227230
img_filename = output_directory + '/comparison_best.png'
228231
utils.save_image(img_merge, img_filename)
229-
232+
230233
save_checkpoint({
231234
'epoch': epoch,
232235
'arch': args.arch,
@@ -278,14 +281,14 @@ def train(train_loader, model, criterion, optimizer, epoch):
278281
'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
279282
'REL={result.absrel:.3f}({average.absrel:.3f}) '
280283
'Lg10={result.lg10:.3f}({average.lg10:.3f}) '.format(
281-
epoch, i+1, len(train_loader), data_time=data_time,
284+
epoch, i+1, len(train_loader), data_time=data_time,
282285
gpu_time=gpu_time, result=result, average=average_meter.average()))
283286

284287
avg = average_meter.average()
285-
with open(train_csv, 'a') as csvfile:
288+
with open(train_csv, 'a') as csvfile:
286289
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
287290
writer.writerow({'mse': avg.mse, 'rmse': avg.rmse, 'absrel': avg.absrel, 'lg10': avg.lg10,
288-
'mae': avg.mae, 'delta1': avg.delta1, 'delta2': avg.delta2, 'delta3': avg.delta3,
291+
'mae': avg.mae, 'delta1': avg.delta1, 'delta2': avg.delta2, 'delta3': avg.delta3,
289292
'gpu_time': avg.gpu_time, 'data_time': avg.data_time})
290293

291294

@@ -364,10 +367,10 @@ def validate(val_loader, model, epoch, write_to_file=True):
364367
average=avg, time=avg.gpu_time))
365368

366369
if write_to_file:
367-
with open(test_csv, 'a') as csvfile:
370+
with open(test_csv, 'a') as csvfile:
368371
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
369372
writer.writerow({'mse': avg.mse, 'rmse': avg.rmse, 'absrel': avg.absrel, 'lg10': avg.lg10,
370-
'mae': avg.mae, 'delta1': avg.delta1, 'delta2': avg.delta2, 'delta3': avg.delta3,
373+
'mae': avg.mae, 'delta1': avg.delta1, 'delta2': avg.delta2, 'delta3': avg.delta3,
371374
'data_time': avg.data_time, 'gpu_time': avg.gpu_time})
372375

373376
return avg, img_merge

0 commit comments

Comments (0)