Skip to content

Commit c6a079c

Browse files
author
Clément Pinard
committed
add inverted flow and upscaling
1 parent 7ebd50c commit c6a079c

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

run_inference.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@
2828
parser.add_argument("--img-exts", default=['png', 'jpg', 'bmp'], nargs='*', type=str, help="images extensions to glob")
2929
parser.add_argument('--max_flow', default=None, type=float,
3030
help='max flow value. Flow map color is saturated above this value. If not set, will use flow map\'s max value')
31-
32-
best_EPE = -1
33-
n_iter = 0
31+
parser.add_argument('--no-resize', action='store_true', help='if set, will output FlowNet raw input, which is 4 times downsampled.'
32+
'if not set, will output full resolution flow map, with bilinear upsampling')
33+
parser.add_argument('--bidirectional', action='store_true', help='if set, will output invert flow (from 1 to 0) along with regular flow')
3434

3535

3636
def main():
37-
global args, best_EPE, save_path
37+
global args, save_path
3838
args = parser.parse_args()
3939
data_dir = Path(args.data)
4040
print("=> fetching img pairs in '{}'".format(args.data))
@@ -72,13 +72,21 @@ def main():
7272

7373
img1 = input_transform(imread(img1_file))
7474
img2 = input_transform(imread(img2_file))
75-
input_var = torch.autograd.Variable(torch.cat([img1, img2],0).cuda(), volatile=True).unsqueeze(0)
75+
input_var = torch.autograd.Variable(torch.cat([img1, img2]).cuda(), volatile=True).unsqueeze(0)
76+
77+
if args.bidirectional:
78+
# feed inverted pair along with normal pair
79+
inverted_input_var = torch.autograd.Variable(torch.cat([img2, img1],0).cuda(), volatile=True).unsqueeze(0)
80+
input_var = torch.cat([input_var, inverted_input_var])
7681

7782
# compute output
7883
output = model(input_var)
79-
rgb_flow = flow2rgb(args.div_flow * output.data[0].cpu().numpy(), max_value=args.max_flow)
80-
to_save = (rgb_flow * 255).astype(np.uint8)
81-
imsave(save_path/(img1_file.namebase[:-2] + '_flow.png'), to_save)
84+
if not args.no_resize:
85+
output = torch.nn.functional.upsample(output, size=img1.size()[-2:], mode='bilinear')
86+
for suffix, flow_output in zip(['flow', 'inv_flow'], output.data.cpu()):
87+
rgb_flow = flow2rgb(args.div_flow * flow_output.numpy(), max_value=args.max_flow)
88+
to_save = (rgb_flow * 255).astype(np.uint8)
89+
imsave(save_path/'{}{}.png'.format(img1_file.namebase[:-1], suffix), to_save)
8290

8391

8492
if __name__ == '__main__':

0 commit comments

Comments
 (0)