Skip to content

Commit d559735

Browse files
committed
Modifications based on PR #91 review
1 parent cb1659b commit d559735

File tree

2 files changed

+3
-6
lines changed

2 files changed

+3
-6
lines changed

main.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,7 @@ def main():
163163
network_data = None
164164
print("=> creating model '{}'".format(args.arch))
165165

166-
if device.type == "cuda":
167-
model = models.__dict__[args.arch](network_data).cuda()
168-
else:
169-
model = models.__dict__[args.arch](network_data).cpu()
166+
model = models.__dict__[args.arch](network_data).to(device)
170167

171168
assert(args.solver in ['adam', 'sgd'])
172169
print('=> setting {} solver'.format(args.solver))

run_inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def main():
7171
for ext in args.img_exts:
7272
test_files = data_dir.files('*1.{}'.format(ext))
7373
for file in test_files:
74-
img_pair = file.parent / (file.basename().splitext()[0][:-1] + '2.{}'.format(ext))
74+
img_pair = file.parent / (file.stem[:-1] + '2.{}'.format(ext))
7575
if img_pair.isfile():
7676
img_pairs.append([file, img_pair])
7777

@@ -103,7 +103,7 @@ def main():
103103
if args.upsampling is not None:
104104
output = F.interpolate(output, size=img1.size()[-2:], mode=args.upsampling, align_corners=False)
105105
for suffix, flow_output in zip(['flow', 'inv_flow'], output):
106-
filename = save_path/'{}{}'.format(img1_file.basename().splitext()[0][:-1], suffix)
106+
filename = save_path/'{}{}'.format(img1_file.stem[:-1], suffix)
107107
if args.output_value in['vis', 'both']:
108108
rgb_flow = flow2rgb(args.div_flow * flow_output, max_value=args.max_flow)
109109
to_save = (rgb_flow * 255).astype(np.uint8).transpose(1,2,0)

0 commit comments

Comments
 (0)