Skip to content

Commit afb354c

Browse files
Merge pull request #46 from bkvie/patch-2
Update run_inference.py
2 parents 9868d72 + fcb92dc commit afb354c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

run_inference.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
'which is 4 times downsampled. If set, will output full resolution flow map, with selected upsampling')
3333
parser.add_argument('--bidirectional', action='store_true', help='if set, will output invert flow (from 1 to 0) along with regular flow')
3434

35-
35+
@torch.no_grad()
3636
def main():
3737
global args, save_path
3838
args = parser.parse_args()
@@ -72,11 +72,11 @@ def main():
7272

7373
img1 = input_transform(imread(img1_file))
7474
img2 = input_transform(imread(img2_file))
75-
input_var = torch.autograd.Variable(torch.cat([img1, img2]).cuda(), volatile=True).unsqueeze(0)
75+
input_var = torch.tensor(torch.cat([img1, img2]).cuda()).unsqueeze(0)
7676

7777
if args.bidirectional:
7878
# feed inverted pair along with normal pair
79-
inverted_input_var = torch.autograd.Variable(torch.cat([img2, img1],0).cuda(), volatile=True).unsqueeze(0)
79+
inverted_input_var = torch.tensor(torch.cat([img2, img1],0).cuda()).unsqueeze(0)
8080
input_var = torch.cat([input_var, inverted_input_var])
8181

8282
# compute output
@@ -85,7 +85,7 @@ def main():
8585
output = torch.nn.functional.upsample(output, size=img1.size()[-2:], mode=args.upsampling)
8686
for suffix, flow_output in zip(['flow', 'inv_flow'], output.data.cpu()):
8787
rgb_flow = flow2rgb(args.div_flow * flow_output.numpy(), max_value=args.max_flow)
88-
to_save = (rgb_flow * 255).astype(np.uint8)
88+
to_save = (rgb_flow * 255).astype(np.uint8).transpose(1,2,0)
8989
imsave(save_path/'{}{}.png'.format(img1_file.namebase[:-1], suffix), to_save)
9090

9191

0 commit comments

Comments
 (0)