Skip to content

Commit 2fd5646

Browse files
author
Clément Pinard
committed
add FlowNetC
1 parent 559ab9e commit 2fd5646

File tree

6 files changed

+209
-39
lines changed

6 files changed

+209
-39
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ Two neural network models are currently provided :
1212

1313
- **FlowNetS**
1414
- **FlowNetSBN**
15-
16-
There is currently no implementation of FlowNetC as a specific Correlation layer module would need to be written (feel free to contribute !)
15+
- **FlowNetC**
16+
- **FlowNetCBN**
1717

1818
## Pretrained Models
1919
Thanks to [Kaixhin](https://github.com/Kaixhin) you can download a pretrained version of FlowNetS (from caffe, not from pytorch) [here](https://drive.google.com/open?id=0B5EC7HMbyk3CbjFPb0RuODI3NmM). This folder also contains trained networks from scratch.
@@ -22,13 +22,13 @@ Thanks to [Kaixhin](https://github.com/Kaixhin) you can download a pretrained ve
2222
Feed the downloaded network directly to the script; you don't need to uncompress it, even if your desktop environment suggests doing so.
2323

2424
### Note on networks from caffe
25-
These networks expect a BGR input in range [-0.5,0.5] (compared to RGB in pytorch). However, BGR order is not very important.
25+
These networks expect a BGR input in range `[-0.5,0.5]` (compared to RGB in pytorch). However, BGR order is not very important.
2626

2727
## Prerequisite
2828

29-
pytorch >= 0.3
29+
pytorch >= 0.4.1
3030
tensorboard-pytorch
31-
tensorboardX
31+
tensorboardX >= 1.4
3232
scipy
3333
argparse
3434

models/FlowNetC.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import torch
2+
import torch.nn as nn
3+
from torch.nn.init import kaiming_normal_, constant_
4+
from .util import conv, predict_flow, deconv, crop_like, correlate
5+
6+
# Names exported by a star-import of this module: the two factory functions
# defined at the bottom of the file.
__all__ = [
    'flownetc', 'flownetc_bn'
]
9+
10+
11+
class FlowNetC(nn.Module):
    """FlowNetC: shared two-stream encoder, correlation layer and flow decoder.

    ``forward`` splits its input along dim 1 into the first three channels
    and the remaining channels, pushes both halves through the shared
    conv1-conv3 stack, correlates the two conv3 outputs and then decodes
    flow maps from the deepest level (flow6) up to the shallowest (flow2).

    Args:
        batchNorm: forwarded to every ``conv`` helper call to select the
            batch-normalised variant of each encoder block.
    """
    expansion = 1

    def __init__(self, batchNorm=True):
        super(FlowNetC, self).__init__()

        self.batchNorm = batchNorm

        # Shared encoder applied separately to each input half.
        self.conv1 = conv(batchNorm, 3, 64, kernel_size=7, stride=2)
        self.conv2 = conv(batchNorm, 64, 128, kernel_size=5, stride=2)
        self.conv3 = conv(batchNorm, 128, 256, kernel_size=5, stride=2)
        # 1x1 projection of the first stream's conv3 features.
        self.conv_redir = conv(batchNorm, 256, 32, kernel_size=1, stride=1)

        # 473 = 32 (conv_redir) + 441 (flattened correlation output).
        self.conv3_1 = conv(batchNorm, 473, 256)
        self.conv4 = conv(batchNorm, 256, 512, stride=2)
        self.conv4_1 = conv(batchNorm, 512, 512)
        self.conv5 = conv(batchNorm, 512, 512, stride=2)
        self.conv5_1 = conv(batchNorm, 512, 512)
        self.conv6 = conv(batchNorm, 512, 1024, stride=2)
        self.conv6_1 = conv(batchNorm, 1024, 1024)

        # Decoder inputs are concatenations of
        # (encoder features, previous deconv output, 2-channel upsampled flow):
        # 1026 = 512 + 512 + 2, 770 = 512 + 256 + 2, 386 = 256 + 128 + 2.
        self.deconv5 = deconv(1024, 512)
        self.deconv4 = deconv(1026, 256)
        self.deconv3 = deconv(770, 128)
        self.deconv2 = deconv(386, 64)

        # One flow head per decoder level; 194 = 128 + 64 + 2.
        self.predict_flow6 = predict_flow(1024)
        self.predict_flow5 = predict_flow(1026)
        self.predict_flow4 = predict_flow(770)
        self.predict_flow3 = predict_flow(386)
        self.predict_flow2 = predict_flow(194)

        # Bias-free transposed convolutions that upsample each predicted flow
        # before it is concatenated into the next decoder level.
        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(
            2, 2, kernel_size=4, stride=2, padding=1, bias=False)
        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(
            2, 2, kernel_size=4, stride=2, padding=1, bias=False)
        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(
            2, 2, kernel_size=4, stride=2, padding=1, bias=False)
        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(
            2, 2, kernel_size=4, stride=2, padding=1, bias=False)

        # Kaiming init with a=0.1 (the LeakyReLU slope used by the helper
        # blocks); biases start at zero, batch-norm starts as identity scale.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                kaiming_normal_(module.weight, a=0.1)
                if module.bias is not None:
                    constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                constant_(module.weight, 1)
                constant_(module.bias, 0)

    def forward(self, x):
        """Predict flow from an input whose dim 1 stacks both inputs.

        Returns:
            In training mode the tuple ``(flow2, flow3, flow4, flow5, flow6)``;
            otherwise only ``flow2``.
        """
        first = x[:, :3]
        second = x[:, 3:]

        # First stream through the shared encoder.
        feat1_a = self.conv1(first)
        feat2_a = self.conv2(feat1_a)
        feat3_a = self.conv3(feat2_a)

        # Second stream through the same layers.
        feat1_b = self.conv1(second)
        feat2_b = self.conv2(feat1_b)
        feat3_b = self.conv3(feat2_b)

        redirected = self.conv_redir(feat3_a)
        correlation = correlate(feat3_a, feat3_b)
        merged = torch.cat([redirected, correlation], dim=1)

        # Contracting part after the two streams have been merged.
        enc3 = self.conv3_1(merged)
        enc4 = self.conv4_1(self.conv4(enc3))
        enc5 = self.conv5_1(self.conv5(enc4))
        enc6 = self.conv6_1(self.conv6(enc5))

        # Level 6 -> 5. Upsampled tensors are cropped to the encoder feature
        # map they are concatenated with.
        flow6 = self.predict_flow6(enc6)
        flow6_up = crop_like(self.upsampled_flow6_to_5(flow6), enc5)
        dec5 = crop_like(self.deconv5(enc6), enc5)
        level5 = torch.cat((enc5, dec5, flow6_up), 1)

        # Level 5 -> 4.
        flow5 = self.predict_flow5(level5)
        flow5_up = crop_like(self.upsampled_flow5_to_4(flow5), enc4)
        dec4 = crop_like(self.deconv4(level5), enc4)
        level4 = torch.cat((enc4, dec4, flow5_up), 1)

        # Level 4 -> 3.
        flow4 = self.predict_flow4(level4)
        flow4_up = crop_like(self.upsampled_flow4_to_3(flow4), enc3)
        dec3 = crop_like(self.deconv3(level4), enc3)
        level3 = torch.cat((enc3, dec3, flow4_up), 1)

        # Level 3 -> 2; the skip connection uses the first stream's conv2.
        flow3 = self.predict_flow3(level3)
        flow3_up = crop_like(self.upsampled_flow3_to_2(flow3), feat2_a)
        dec2 = crop_like(self.deconv2(level3), feat2_a)
        level2 = torch.cat((feat2_a, dec2, flow3_up), 1)

        flow2 = self.predict_flow2(level2)

        if not self.training:
            return flow2
        return flow2, flow3, flow4, flow5, flow6

    def weight_parameters(self):
        """Return every parameter whose qualified name contains 'weight'."""
        return [tensor for key, tensor in self.named_parameters() if 'weight' in key]

    def bias_parameters(self):
        """Return every parameter whose qualified name contains 'bias'."""
        return [tensor for key, tensor in self.named_parameters() if 'bias' in key]
110+
111+
112+
def flownetc(data=None):
    """Build a FlowNetC model without batch normalisation.

    Architecture from the "Learning Optical Flow with Convolutional Networks"
    paper (https://arxiv.org/abs/1504.06852).

    Args:
        data : pretrained weights of the network; when given, its
            'state_dict' entry is loaded into the model. A freshly
            initialised network is returned if not set.
    """
    network = FlowNetC(batchNorm=False)
    if data is None:
        return network
    network.load_state_dict(data['state_dict'])
    return network
123+
124+
125+
def flownetc_bn(data=None):
    """Build a FlowNetC model with batch normalisation enabled.

    Architecture from the "Learning Optical Flow with Convolutional Networks"
    paper (https://arxiv.org/abs/1504.06852).

    Args:
        data : pretrained weights of the network; when given, its
            'state_dict' entry is loaded into the model. A freshly
            initialised network is returned if not set.
    """
    network = FlowNetC(batchNorm=True)
    if data is None:
        return network
    network.load_state_dict(data['state_dict'])
    return network

models/FlowNetS.py

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,13 @@
11
import torch
22
import torch.nn as nn
3-
from torch.nn.init import kaiming_normal
3+
from torch.nn.init import kaiming_normal_, constant_
4+
from .util import conv, predict_flow, deconv, crop_like
45

56
__all__ = [
67
'flownets', 'flownets_bn'
78
]
89

910

10-
def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
11-
if batchNorm:
12-
return nn.Sequential(
13-
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False),
14-
nn.BatchNorm2d(out_planes),
15-
nn.LeakyReLU(0.1,inplace=True)
16-
)
17-
else:
18-
return nn.Sequential(
19-
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
20-
nn.LeakyReLU(0.1,inplace=True)
21-
)
22-
23-
24-
def predict_flow(in_planes):
25-
return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=False)
26-
27-
28-
def deconv(in_planes, out_planes):
29-
return nn.Sequential(
30-
nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False),
31-
nn.LeakyReLU(0.1,inplace=True)
32-
)
33-
34-
35-
def crop_like(input, target):
36-
if input.size()[2:] == target.size()[2:]:
37-
return input
38-
else:
39-
return input[:, :, :target.size(2), :target.size(3)]
40-
41-
4211
class FlowNetS(nn.Module):
4312
expansion = 1
4413

models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .FlowNetS import *
1+
from .FlowNetS import *
2+
from .FlowNetC import *

models/util.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import torch.nn as nn
2+
import torch.nn.functional as F
3+
4+
# The correlation sampler is an optional third-party extension. It is only
# needed by `correlate` (used by FlowNetC), so a missing package must not
# prevent the rest of this module from being imported.
try:
    from spatial_correlation_sampler import spatial_correlation_sample
except ImportError:
    import warnings
    with warnings.catch_warnings():
        # ImportWarning is ignored by Python's default filters; enable it
        # locally so the user actually sees the message.
        warnings.filterwarnings("default", category=ImportWarning)
        # The trailing space matters: the two literals are concatenated.
        warnings.warn("failed to load custom correlation module "
                      "which is needed for FlowNetC", ImportWarning)
12+
13+
14+
def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
    """Build a convolution block: Conv2d, optional BatchNorm2d, LeakyReLU(0.1).

    Args:
        batchNorm: when truthy, a BatchNorm2d layer follows the convolution
            and the convolution is created without a bias term.
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: convolution kernel size; padding is (kernel_size-1)//2.
        stride: convolution stride.

    Returns:
        An ``nn.Sequential`` holding the layers in the order listed above.
    """
    layers = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            # The bias is only kept when no batch-norm layer follows.
            bias=not batchNorm,
        )
    ]
    if batchNorm:
        layers.append(nn.BatchNorm2d(out_planes))
    layers.append(nn.LeakyReLU(0.1, inplace=True))
    return nn.Sequential(*layers)
26+
27+
28+
def predict_flow(in_planes):
    """Build a flow head: a bias-free 3x3 convolution from in_planes to 2 channels."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=2,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
30+
31+
32+
def deconv(in_planes, out_planes):
    """Build an upsampling block: bias-free ConvTranspose2d followed by LeakyReLU(0.1).

    The transposed convolution uses kernel_size=4, stride=2 and padding=1.
    """
    upsample = nn.ConvTranspose2d(
        in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False
    )
    activation = nn.LeakyReLU(0.1, inplace=True)
    return nn.Sequential(upsample, activation)
37+
38+
39+
def correlate(input1, input2):
    """Correlate two feature maps and return the result as a 4D tensor.

    Calls ``spatial_correlation_sample`` with kernel_size=1, patch_size=21,
    stride=1, padding=0 and dilation_patch=2. The 5D result is reshaped by
    merging dimensions 1 and 2, divided by the channel count of ``input1``
    and passed through an in-place leaky ReLU with slope 0.1.
    """
    volume = spatial_correlation_sample(
        input1,
        input2,
        kernel_size=1,
        patch_size=21,
        stride=1,
        padding=0,
        dilation_patch=2,
    )
    # Merge the two patch dimensions so the result can be treated as a
    # regular 4D tensor.
    batch, patch_h, patch_w, height, width = volume.size()
    flattened = volume.view(batch, patch_h * patch_w, height, width)
    normalised = flattened / input1.size(1)
    return F.leaky_relu_(normalised, 0.1)
52+
53+
54+
def crop_like(input, target):
    """Crop ``input`` to the trailing (dim 2 onward) size of ``target``.

    ``input`` is returned unchanged when the sizes from dim 2 onward already
    match; otherwise dims 2 and 3 are sliced from index 0 up to the
    corresponding sizes of ``target``.
    """
    sizes_match = input.size()[2:] == target.size()[2:]
    if sizes_match:
        return input
    rows, cols = target.size(2), target.size(3)
    return input[:, :, :rows, :cols]

requirements.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
torch>=0.4.1
2+
torchvision
3+
numpy
4+
spatial-correlation-sampler
5+
tensorboardX>=1.4
6+
imageio
7+
argparse

0 commit comments

Comments
 (0)