Skip to content

Commit dcf1937

Browse files
author
Fangchang Ma
committed
fixed bug from previous commit
1 parent c23442f commit dcf1937

File tree

1 file changed

+73
-67
lines changed

1 file changed

+73
-67
lines changed

main.py

Lines changed: 73 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,57 @@
2323
best_result = Result()
2424
best_result.set_to_worst()
2525

26+
def create_data_loaders(args):
27+
# Data loading code
28+
print("=> creating data loaders ...")
29+
traindir = os.path.join('data', args.data, 'train')
30+
valdir = os.path.join('data', args.data, 'val')
31+
train_loader = None
32+
val_loader = None
33+
34+
# sparsifier is a class for generating random sparse depth input from the ground truth
35+
sparsifier = None
36+
max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
37+
if args.sparsifier == UniformSampling.name:
38+
sparsifier = UniformSampling(num_samples=args.num_samples, max_depth=max_depth)
39+
elif args.sparsifier == SimulatedStereo.name:
40+
sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth)
41+
42+
if args.data == 'nyudepthv2':
43+
from dataloaders.nyu_dataloader import NYUDataset
44+
if not args.evaluate:
45+
train_dataset = NYUDataset(traindir, type='train',
46+
modality=args.modality, sparsifier=sparsifier)
47+
val_dataset = NYUDataset(valdir, type='val',
48+
modality=args.modality, sparsifier=sparsifier)
49+
50+
elif args.data == 'kitti':
51+
from dataloaders.kitti_dataloader import KITTIDataset
52+
if not args.evaluate:
53+
train_dataset = KITTIDataset(traindir, type='train',
54+
modality=args.modality, sparsifier=sparsifier)
55+
val_dataset = KITTIDataset(valdir, type='val',
56+
modality=args.modality, sparsifier=sparsifier)
57+
58+
else:
59+
raise RuntimeError('Dataset not found.' +
60+
'The dataset must be either of nyudepthv2 or kitti.')
61+
62+
# set batch size to be 1 for validation
63+
val_loader = torch.utils.data.DataLoader(val_dataset,
64+
batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
65+
66+
# put construction of train loader here, for those who are interested in testing only
67+
if not args.evaluate:
68+
train_loader = torch.utils.data.DataLoader(
69+
train_dataset, batch_size=args.batch_size, shuffle=True,
70+
num_workers=args.workers, pin_memory=True, sampler=None,
71+
worker_init_fn=lambda work_id:np.random.seed(work_id))
72+
# worker_init_fn ensures different sampling patterns for each data loading thread
73+
74+
print("=> data loaders created.")
75+
return train_loader, val_loader
76+
2677
def main():
2778
global args, best_result, output_directory, train_csv, test_csv
2879

@@ -33,12 +84,16 @@ def main():
3384
"=> no best model found at '{}'".format(args.evaluate)
3485
print("=> loading best model '{}'".format(args.evaluate))
3586
checkpoint = torch.load(args.evaluate)
87+
output_directory = os.path.dirname(args.evaluate)
3688
args = checkpoint['args']
37-
args.evaluate = True
3889
start_epoch = checkpoint['epoch'] + 1
3990
best_result = checkpoint['best_result']
4091
model = checkpoint['model']
4192
print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
93+
_, val_loader = create_data_loaders(args)
94+
args.evaluate = True
95+
validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
96+
return
4297

4398
# optionally resume from a checkpoint
4499
elif args.resume:
@@ -51,93 +106,35 @@ def main():
51106
best_result = checkpoint['best_result']
52107
model = checkpoint['model']
53108
optimizer = checkpoint['optimizer']
54-
output_directory, _ = os.path.split(args.resume)
109+
output_directory = os.path.dirname(os.path.abspath(args.resume))
55110
print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
111+
train_loader, val_loader = create_data_loaders(args)
112+
args.resume = True
56113

57114
# create new model
58115
else:
59-
# define model
116+
train_loader, val_loader = create_data_loaders(args)
60117
print("=> creating Model ({}-{}) ...".format(args.arch, args.decoder))
61118
in_channels = len(args.modality)
62119
if args.arch == 'resnet50':
63-
model = ResNet(layers=50, decoder=args.decoder, output_size=train_dataset.output_size,
120+
model = ResNet(layers=50, decoder=args.decoder, output_size=train_loader.dataset.output_size,
64121
in_channels=in_channels, pretrained=args.pretrained)
65122
elif args.arch == 'resnet18':
66-
model = ResNet(layers=18, decoder=args.decoder, output_size=train_dataset.output_size,
123+
model = ResNet(layers=18, decoder=args.decoder, output_size=train_loader.dataset.output_size,
67124
in_channels=in_channels, pretrained=args.pretrained)
68125
print("=> model created.")
69-
70126
optimizer = torch.optim.SGD(model.parameters(), args.lr, \
71127
momentum=args.momentum, weight_decay=args.weight_decay)
72128

73-
# create new csv files with only header
74-
with open(train_csv, 'w') as csvfile:
75-
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
76-
writer.writeheader()
77-
with open(test_csv, 'w') as csvfile:
78-
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
79-
writer.writeheader()
80-
81-
# model = torch.nn.DataParallel(model).cuda() # for multi-gpu training
82-
model = model.cuda()
83-
# print(model)
84-
print("=> model transferred to GPU.")
129+
# model = torch.nn.DataParallel(model).cuda() # for multi-gpu training
130+
model = model.cuda()
85131

86132
# define loss function (criterion) and optimizer
87133
if args.criterion == 'l2':
88134
criterion = criteria.MaskedMSELoss().cuda()
89135
elif args.criterion == 'l1':
90136
criterion = criteria.MaskedL1Loss().cuda()
91137

92-
# sparsifier is a class for generating random sparse depth input from the ground truth
93-
sparsifier = None
94-
max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
95-
if args.sparsifier == UniformSampling.name:
96-
sparsifier = UniformSampling(num_samples=args.num_samples, max_depth=max_depth)
97-
elif args.sparsifier == SimulatedStereo.name:
98-
sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth)
99-
100-
# Data loading code
101-
print("=> creating data loaders ...")
102-
traindir = os.path.join('data', args.data, 'train')
103-
valdir = os.path.join('data', args.data, 'val')
104-
105-
if args.data == 'nyudepthv2':
106-
from dataloaders.nyu_dataloader import NYUDataset
107-
if not args.evaluate:
108-
train_dataset = NYUDataset(traindir, type='train',
109-
modality=args.modality, sparsifier=sparsifier)
110-
val_dataset = NYUDataset(valdir, type='val',
111-
modality=args.modality, sparsifier=sparsifier)
112-
113-
elif args.data == 'kitti':
114-
from dataloaders.kitti_dataloader import KITTIDataset
115-
if not args.evaluate:
116-
train_dataset = KITTIDataset(traindir, type='train',
117-
modality=args.modality, sparsifier=sparsifier)
118-
val_dataset = KITTIDataset(valdir, type='val',
119-
modality=args.modality, sparsifier=sparsifier)
120-
121-
else:
122-
raise RuntimeError('Dataset not found.' +
123-
'The dataset must be either of nyudepthv2 or kitti.')
124-
125-
# set batch size to be 1 for validation
126-
val_loader = torch.utils.data.DataLoader(val_dataset,
127-
batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
128-
print("=> data loaders created.")
129-
130-
if args.evaluate:
131-
validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
132-
return
133-
134-
# put construction of train loader here, for those who are interested in testing only
135-
train_loader = torch.utils.data.DataLoader(
136-
train_dataset, batch_size=args.batch_size, shuffle=True,
137-
num_workers=args.workers, pin_memory=True, sampler=None,
138-
worker_init_fn=lambda work_id:np.random.seed(work_id))
139-
# worker_init_fn ensures different sampling patterns for each data loading thread
140-
141138
# create results folder, if not already exists
142139
output_directory = utils.get_output_directory(args)
143140
if not os.path.exists(output_directory):
@@ -146,6 +143,15 @@ def main():
146143
test_csv = os.path.join(output_directory, 'test.csv')
147144
best_txt = os.path.join(output_directory, 'best.txt')
148145

146+
# create new csv files with only header
147+
if not args.resume:
148+
with open(train_csv, 'w') as csvfile:
149+
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
150+
writer.writeheader()
151+
with open(test_csv, 'w') as csvfile:
152+
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
153+
writer.writeheader()
154+
149155
for epoch in range(start_epoch, args.epochs):
150156
utils.adjust_learning_rate(optimizer, epoch, args.lr)
151157
train(train_loader, model, criterion, optimizer, epoch) # train for one epoch

0 commit comments

Comments
 (0)