Skip to content

Commit 678ac41

Browse files
author
Clément Pinard
committed
solves #51
Fx run_inference.py script it is now pytorch 0.4.1 compliant, especially with new 'flow2rg' function
1 parent 4e44aa6 commit 678ac41

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

run_inference.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
import torch.backends.cudnn as cudnn
6+
import torch.nn.functional as F
67
import models
78
from tqdm import tqdm
89
import torchvision.transforms as transforms
@@ -32,13 +33,19 @@
3233
'which is 4 times downsampled. If set, will output full resolution flow map, with selected upsampling')
3334
parser.add_argument('--bidirectional', action='store_true', help='if set, will output invert flow (from 1 to 0) along with regular flow')
3435

36+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
37+
38+
3539
@torch.no_grad()
3640
def main():
3741
global args, save_path
3842
args = parser.parse_args()
3943
data_dir = Path(args.data)
4044
print("=> fetching img pairs in '{}'".format(args.data))
41-
save_path = data_dir/'flow'
45+
if args.output is None:
46+
save_path = data_dir/'flow'
47+
else:
48+
save_path = Path(args.output)
4249
print('=> will save everything to {}'.format(save_path))
4350
save_path.makedirs_p()
4451

@@ -61,7 +68,7 @@ def main():
6168
# create model
6269
network_data = torch.load(args.pretrained)
6370
print("=> using pre-trained model '{}'".format(network_data['arch']))
64-
model = models.__dict__[network_data['arch']](network_data).cuda()
71+
model = models.__dict__[network_data['arch']](network_data).to(device)
6572
model.eval()
6673
cudnn.benchmark = True
6774

@@ -72,19 +79,20 @@ def main():
7279

7380
img1 = input_transform(imread(img1_file))
7481
img2 = input_transform(imread(img2_file))
75-
input_var = torch.tensor(torch.cat([img1, img2]).cuda()).unsqueeze(0)
82+
input_var = torch.cat([img1, img2]).unsqueeze(0)
7683

7784
if args.bidirectional:
7885
# feed inverted pair along with normal pair
79-
inverted_input_var = torch.tensor(torch.cat([img2, img1],0).cuda()).unsqueeze(0)
86+
inverted_input_var = torch.cat([img2, img1]).unsqueeze(0)
8087
input_var = torch.cat([input_var, inverted_input_var])
8188

89+
input_var = input_var.to(device)
8290
# compute output
8391
output = model(input_var)
8492
if args.upsampling is not None:
85-
output = torch.nn.functional.upsample(output, size=img1.size()[-2:], mode=args.upsampling)
86-
for suffix, flow_output in zip(['flow', 'inv_flow'], output.data.cpu()):
87-
rgb_flow = flow2rgb(args.div_flow * flow_output.numpy(), max_value=args.max_flow)
93+
output = F.interpolate(output, size=img1.size()[-2:], mode=args.upsampling, align_corners=False)
94+
for suffix, flow_output in zip(['flow', 'inv_flow'], output):
95+
rgb_flow = flow2rgb(args.div_flow * flow_output, max_value=args.max_flow)
8896
to_save = (rgb_flow * 255).astype(np.uint8).transpose(1,2,0)
8997
imsave(save_path/'{}{}.png'.format(img1_file.namebase[:-1], suffix), to_save)
9098

0 commit comments

Comments
 (0)