Skip to content

Commit 0154616

Browse files
Merge pull request #91 from mickaelseznec/master
Run on CPU + minor updates
2 parents 53248d1 + d559735 commit 0154616

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

main.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,17 @@ def main():
163163
network_data = None
164164
print("=> creating model '{}'".format(args.arch))
165165

166-
model = models.__dict__[args.arch](network_data).cuda()
167-
model = torch.nn.DataParallel(model).cuda()
168-
cudnn.benchmark = True
166+
model = models.__dict__[args.arch](network_data).to(device)
169167

170168
assert(args.solver in ['adam', 'sgd'])
171169
print('=> setting {} solver'.format(args.solver))
172-
param_groups = [{'params': model.module.bias_parameters(), 'weight_decay': args.bias_decay},
173-
{'params': model.module.weight_parameters(), 'weight_decay': args.weight_decay}]
170+
param_groups = [{'params': model.bias_parameters(), 'weight_decay': args.bias_decay},
171+
{'params': model.weight_parameters(), 'weight_decay': args.weight_decay}]
172+
173+
if device.type == "cuda":
174+
model = torch.nn.DataParallel(model).cuda()
175+
cudnn.benchmark = True
176+
174177
if args.solver == 'adam':
175178
optimizer = torch.optim.Adam(param_groups, args.lr,
176179
betas=(args.momentum, args.beta))

requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@ spatial-correlation-sampler>=0.2.1
55
tensorboardX>=1.4
66
imageio
77
argparse
8-
path.py
8+
path
9+
tqdm
10+
scipy

run_inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def main():
7171
for ext in args.img_exts:
7272
test_files = data_dir.files('*1.{}'.format(ext))
7373
for file in test_files:
74-
img_pair = file.parent / (file.namebase[:-1] + '2.{}'.format(ext))
74+
img_pair = file.parent / (file.stem[:-1] + '2.{}'.format(ext))
7575
if img_pair.isfile():
7676
img_pairs.append([file, img_pair])
7777

@@ -103,7 +103,7 @@ def main():
103103
if args.upsampling is not None:
104104
output = F.interpolate(output, size=img1.size()[-2:], mode=args.upsampling, align_corners=False)
105105
for suffix, flow_output in zip(['flow', 'inv_flow'], output):
106-
filename = save_path/'{}{}'.format(img1_file.namebase[:-1], suffix)
106+
filename = save_path/'{}{}'.format(img1_file.stem[:-1], suffix)
107107
if args.output_value in['vis', 'both']:
108108
rgb_flow = flow2rgb(args.div_flow * flow_output, max_value=args.max_flow)
109109
to_save = (rgb_flow * 255).astype(np.uint8).transpose(1,2,0)

0 commit comments

Comments
 (0)