|
28 | 28 | parser.add_argument("--img-exts", default=['png', 'jpg', 'bmp'], nargs='*', type=str, help="images extensions to glob") |
29 | 29 | parser.add_argument('--max_flow', default=None, type=float, |
30 | 30 | help='max flow value. Flow map color is saturated above this value. If not set, will use flow map\'s max value') |
31 | | - |
32 | | -best_EPE = -1 |
33 | | -n_iter = 0 |
| 31 | +parser.add_argument('--no-resize', action='store_true', help='if set, will output FlowNet raw input, which is 4 times downsampled.' |
| 32 | + 'if not set, will output full resolution flow map, with bilinear upsampling') |
| 33 | +parser.add_argument('--bidirectional', action='store_true', help='if set, will output invert flow (from 1 to 0) along with regular flow') |
34 | 34 |
|
35 | 35 |
|
36 | 36 | def main(): |
37 | | - global args, best_EPE, save_path |
| 37 | + global args, save_path |
38 | 38 | args = parser.parse_args() |
39 | 39 | data_dir = Path(args.data) |
40 | 40 | print("=> fetching img pairs in '{}'".format(args.data)) |
@@ -72,13 +72,21 @@ def main(): |
72 | 72 |
|
73 | 73 | img1 = input_transform(imread(img1_file)) |
74 | 74 | img2 = input_transform(imread(img2_file)) |
75 | | - input_var = torch.autograd.Variable(torch.cat([img1, img2],0).cuda(), volatile=True).unsqueeze(0) |
| 75 | + input_var = torch.autograd.Variable(torch.cat([img1, img2]).cuda(), volatile=True).unsqueeze(0) |
| 76 | + |
| 77 | + if args.bidirectional: |
| 78 | + # feed inverted pair along with normal pair |
| 79 | + inverted_input_var = torch.autograd.Variable(torch.cat([img2, img1],0).cuda(), volatile=True).unsqueeze(0) |
| 80 | + input_var = torch.cat([input_var, inverted_input_var]) |
76 | 81 |
|
77 | 82 | # compute output |
78 | 83 | output = model(input_var) |
79 | | - rgb_flow = flow2rgb(args.div_flow * output.data[0].cpu().numpy(), max_value=args.max_flow) |
80 | | - to_save = (rgb_flow * 255).astype(np.uint8) |
81 | | - imsave(save_path/(img1_file.namebase[:-2] + '_flow.png'), to_save) |
| 84 | + if not args.no_resize: |
| 85 | + output = torch.nn.functional.upsample(output, size=img1.size()[-2:], mode='bilinear') |
| 86 | + for suffix, flow_output in zip(['flow', 'inv_flow'], output.data.cpu()): |
| 87 | + rgb_flow = flow2rgb(args.div_flow * flow_output.numpy(), max_value=args.max_flow) |
| 88 | + to_save = (rgb_flow * 255).astype(np.uint8) |
| 89 | + imsave(save_path/'{}{}.png'.format(img1_file.namebase[:-1], suffix), to_save) |
82 | 90 |
|
83 | 91 |
|
84 | 92 | if __name__ == '__main__': |
|
0 commit comments