Skip to content

Commit 24bdd7b

Browse files
add hf integration
1 parent 6bdd214 commit 24bdd7b

File tree

1 file changed

+96
-0
lines changed

1 file changed

+96
-0
lines changed

infer_hf.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import torch
2+
from glob import glob
3+
import os
4+
import numpy as np
5+
import cv2
6+
from NeuFlow.neuflow import NeuFlow
7+
from NeuFlow.backbone_v7 import ConvBlock
8+
from data_utils import flow_viz
9+
10+
11+
image_width = 768
12+
image_height = 432
13+
14+
def get_cuda_image(image_path):
15+
image = cv2.imread(image_path)
16+
17+
image = cv2.resize(image, (image_width, image_height))
18+
19+
image = torch.from_numpy(image).permute(2, 0, 1).half()
20+
return image[None].cuda()
21+
22+
23+
def fuse_conv_and_bn(conv, bn):
24+
"""Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
25+
fusedconv = (
26+
torch.nn.Conv2d(
27+
conv.in_channels,
28+
conv.out_channels,
29+
kernel_size=conv.kernel_size,
30+
stride=conv.stride,
31+
padding=conv.padding,
32+
dilation=conv.dilation,
33+
groups=conv.groups,
34+
bias=True,
35+
)
36+
.requires_grad_(False)
37+
.to(conv.weight.device)
38+
)
39+
40+
# Prepare filters
41+
w_conv = conv.weight.clone().view(conv.out_channels, -1)
42+
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
43+
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
44+
45+
# Prepare spatial bias
46+
b_conv = torch.zeros(conv.weight.shape[0], device=conv.weight.device) if conv.bias is None else conv.bias
47+
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
48+
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
49+
50+
return fusedconv
51+
52+
53+
image_path_list = sorted(glob('test_images/*.jpg'))
54+
vis_path = 'test_results/'
55+
56+
device = torch.device('cuda')
57+
58+
model = NeuFlow.from_pretrained("Study-is-happy/neuflow-v2").to(device)
59+
60+
for m in model.modules():
61+
if type(m) is ConvBlock:
62+
m.conv1 = fuse_conv_and_bn(m.conv1, m.norm1) # update conv
63+
m.conv2 = fuse_conv_and_bn(m.conv2, m.norm2) # update conv
64+
delattr(m, "norm1") # remove batchnorm
65+
delattr(m, "norm2") # remove batchnorm
66+
m.forward = m.forward_fuse # update forward
67+
68+
model.eval()
69+
model.half()
70+
71+
model.init_bhwd(1, image_height, image_width, 'cuda')
72+
73+
if not os.path.exists(vis_path):
74+
os.makedirs(vis_path)
75+
76+
for image_path_0, image_path_1 in zip(image_path_list[:-1], image_path_list[1:]):
77+
78+
print(image_path_0)
79+
80+
image_0 = get_cuda_image(image_path_0)
81+
image_1 = get_cuda_image(image_path_1)
82+
83+
file_name = os.path.basename(image_path_0)
84+
85+
with torch.no_grad():
86+
87+
flow = model(image_0, image_1)[-1][0]
88+
89+
flow = flow.permute(1,2,0).cpu().numpy()
90+
91+
flow = flow_viz.flow_to_image(flow)
92+
93+
image_0 = cv2.resize(cv2.imread(image_path_0), (image_width, image_height))
94+
95+
cv2.imwrite(vis_path + file_name, np.vstack([image_0, flow]))
96+

0 commit comments

Comments
 (0)