Skip to content

Commit a5758aa

Browse files
committed
[Correlation] cpu version
1 parent 1ba9724 commit a5758aa

File tree

1 file changed

+175
-1
lines changed

1 file changed

+175
-1
lines changed

src/caffe/layers/correlation_layer.cpp

Lines changed: 175 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,136 @@
99

1010
namespace caffe {
1111

12+
template<typename Dtype>
13+
void blob_rearrange_kernel2(Dtype *dout, const Dtype *din, int num, int channels, int height, int width, int pheight, int pwidth, int padding)
14+
{
15+
// dout[num][pwidthheight][widthheight][channels]
16+
// din[num][channels][height][width]
17+
18+
for (int n = 0; n < num; ++n)
19+
for (int ch = 0; ch < channels; ++ch)
20+
for (int y = 0; y < height; ++y)
21+
for (int x = 0; x < width; ++x)
22+
dout[((n * pheight + y + padding) * pwidth + x + padding) * channels + ch] =
23+
din[((n * channels + ch) * height + y) * width + x];
24+
}
25+
26+
template <typename Dtype>
27+
void CorrelateData(int num, int topwidth, int topheight, int topchannels, int topcount,
28+
int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int kernel_size, int stride1, int stride2,
29+
int bottomwidth, int bottomheight, int bottomchannels,
30+
const Dtype *bottom0, const Dtype *bottom1, Dtype *top)
31+
{
32+
33+
34+
for (int n = 0; n < num; ++n)
35+
{
36+
Dtype patch_data[kernel_size * kernel_size * bottomchannels];
37+
38+
for (int y = 0; y < topheight; ++y)
39+
for (int x = 0; x < topwidth; ++x)
40+
{
41+
int x1 = x * stride1 + max_displacement;
42+
int y1 = y * stride1 + max_displacement;
43+
44+
// Load 3D patch into shared shared memory
45+
for (int j = 0; j < kernel_size; j++) // HEIGHT
46+
for (int i = 0; i < kernel_size; i++) // WIDTH
47+
{
48+
int ji_off = ((j * kernel_size) + i) * bottomchannels;
49+
for (int ch = 0; ch < bottomchannels; ch++) // CHANNELS
50+
{
51+
int idx1 = ((n * bottomheight + y1 + j) * bottomwidth + x1 + i) * bottomchannels + ch;
52+
int idxPatchData = ji_off + ch;
53+
patch_data[idxPatchData] = bottom0[idx1];
54+
}
55+
}
56+
57+
// Compute correlation
58+
for (int top_channel = 0; top_channel < topchannels; top_channel++)
59+
{
60+
double sum = 0;
61+
62+
int s2o = (top_channel % neighborhood_grid_width - neighborhood_grid_radius) * stride2;
63+
int s2p = (top_channel / neighborhood_grid_width - neighborhood_grid_radius) * stride2;
64+
65+
for (int j = 0; j < kernel_size; j++) // HEIGHT
66+
for (int i = 0; i < kernel_size; i++) // WIDTH
67+
{
68+
int ji_off = ((j * kernel_size) + i) * bottomchannels;
69+
for (int ch = 0; ch < bottomchannels; ch++) // CHANNELS
70+
{
71+
int x2 = x1 + s2o;
72+
int y2 = y1 + s2p;
73+
74+
int idxPatchData = ji_off + ch;
75+
int idx2 = ((n * bottomheight + y2 + j) * bottomwidth + x2 + i) * bottomchannels + ch;
76+
77+
sum += patch_data[idxPatchData] * bottom1[idx2];
78+
}
79+
}
80+
81+
82+
83+
84+
const int sumelems = kernel_size * kernel_size * bottomchannels;
85+
const int index = ((top_channel * topheight + y) * topwidth) + x;
86+
87+
// printf("%f\n", sum / (float) sumelems);
88+
89+
top[index + n * topcount] = sum / (float) sumelems;
90+
}
91+
}
92+
}
93+
// Aggregate
94+
}
95+
96+
template <typename Dtype>
97+
void CorrelateDataSubtract(int num, int item, int topwidth, int topheight, int topchannels, int topcount,
98+
int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int stride1, int stride2,
99+
int bottomwidth, int bottomheight, int bottomchannels,
100+
const Dtype *bottom0, const Dtype *bottom1, Dtype *top)
101+
{
102+
for (int index = 0; index < topcount; index++)
103+
{
104+
int x = index % topwidth; //w-pos
105+
int y = (index / topwidth) % topheight; //h-pos
106+
int c = (index / topwidth / topheight) % topchannels; //channels
107+
108+
// Offset of patch in image 2
109+
int s2o = (c % neighborhood_grid_width - neighborhood_grid_radius) * stride2;
110+
int s2p = (c / neighborhood_grid_width - neighborhood_grid_radius) * stride2;
111+
112+
// First (upper left) position of kernel center in current neighborhood in image 1
113+
int x1 = x * stride1 + kernel_radius + max_displacement;
114+
int y1 = y * stride1 + kernel_radius + max_displacement;
115+
116+
// Iterate through 3D patch
117+
Dtype sum = 0;
118+
for (int j = -kernel_radius; j <= kernel_radius; j++)
119+
{ // HEIGHT
120+
for (int i = -kernel_radius; i <= kernel_radius; i++)
121+
{ // WIDTH
122+
for (int l = 0; l < bottomchannels; l++)
123+
{ // CHANNELS
124+
// Calculate position in image 2
125+
int x2 = x1 + s2o;
126+
int y2 = y1 + s2p;
127+
128+
// Indices in bottom data: (CH=l,W=x2,H=y2,N)
129+
int idx1 = ((item * bottomheight + y1 + j) * bottomwidth + x1 + i) * bottomchannels + l;
130+
int idx2 = ((item * bottomheight + y2 + j) * bottomwidth + x2 + i) * bottomchannels + l;
131+
132+
// Do the correlation:
133+
sum += abs(bottom0[idx1] - bottom1[idx2]);
134+
}
135+
}
136+
}
137+
const int sumelems = (kernel_radius * 2 + 1)*(kernel_radius * 2 + 1) * bottomchannels;
138+
top[index + item * topcount] = sum / (float) sumelems;
139+
}
140+
}
141+
12142
template <typename Dtype>
13143
void CorrelationLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
14144
const vector<Blob<Dtype>*>& top) {
@@ -86,7 +216,51 @@ void CorrelationLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
86216
template <typename Dtype>
87217
void CorrelationLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
88218
const vector<Blob<Dtype>*>& top) {
89-
NOT_IMPLEMENTED;
219+
220+
CHECK_EQ(bottom.size(), 2);
221+
CHECK_EQ(top.size(), 1);
222+
223+
const int bnum = bottom[0]->num();
224+
const int bchannels = bottom[0]->channels();
225+
const int bheight = bottom[0]->height();
226+
const int bwidth = bottom[0]->width();
227+
228+
const int topcount = top_width_ * top_height_ * top_channels_;
229+
const int pheight = bheight + 2 * pad_size_;
230+
const int pwidth = bwidth + 2 * pad_size_;
231+
232+
blob_rearrange_kernel2<Dtype>(rbot1_->mutable_cpu_data(), bottom[0]->cpu_data(), bnum, bchannels, bheight, bwidth, pheight, pwidth, pad_size_);
233+
blob_rearrange_kernel2<Dtype>(rbot2_->mutable_cpu_data(), bottom[1]->cpu_data(), bnum, bchannels, bheight, bwidth, pheight, pwidth, pad_size_);
234+
235+
const int num = bnum;
236+
const int channels = bchannels;
237+
const int height = bheight + 2 * pad_size_;
238+
const int width = bwidth + 2 * pad_size_;
239+
240+
if(corr_type_ == CorrelationParameter_CorrelationType_MULTIPLY) {
241+
// CorrelationLayer
242+
243+
CorrelateData<Dtype>(
244+
num, top_width_, top_height_, top_channels_, topcount,
245+
max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_, kernel_size_,
246+
stride1_, stride2_,
247+
width, height, channels,
248+
rbot1_->cpu_data(), rbot2_->cpu_data(), top[0]->mutable_cpu_data()
249+
);
250+
251+
} else if(corr_type_ == CorrelationParameter_CorrelationType_SUBTRACT) {
252+
// CorrelationLayer
253+
for(int n = 0; n < num; n++) {
254+
255+
CorrelateDataSubtract<Dtype>(
256+
num, n, top_width_, top_height_, top_channels_, topcount,
257+
max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_,
258+
stride1_, stride2_,
259+
width, height, channels,
260+
rbot1_->cpu_data(), rbot2_->cpu_data(), top[0]->mutable_cpu_data()
261+
);
262+
}
263+
}
90264
}
91265

92266
template <typename Dtype>

0 commit comments

Comments
 (0)