|
9 | 9 |
|
10 | 10 | namespace caffe { |
11 | 11 |
|
| 12 | +template<typename Dtype> |
| 13 | +void blob_rearrange_kernel2(Dtype *dout, const Dtype *din, int num, int channels, int height, int width, int pheight, int pwidth, int padding) |
| 14 | +{ |
| 15 | + // dout[num][pwidthheight][widthheight][channels] |
| 16 | + // din[num][channels][height][width] |
| 17 | + |
| 18 | + for (int n = 0; n < num; ++n) |
| 19 | + for (int ch = 0; ch < channels; ++ch) |
| 20 | + for (int y = 0; y < height; ++y) |
| 21 | + for (int x = 0; x < width; ++x) |
| 22 | + dout[((n * pheight + y + padding) * pwidth + x + padding) * channels + ch] = |
| 23 | + din[((n * channels + ch) * height + y) * width + x]; |
| 24 | +} |
| 25 | + |
| 26 | +template <typename Dtype> |
| 27 | +void CorrelateData(int num, int topwidth, int topheight, int topchannels, int topcount, |
| 28 | + int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int kernel_size, int stride1, int stride2, |
| 29 | + int bottomwidth, int bottomheight, int bottomchannels, |
| 30 | + const Dtype *bottom0, const Dtype *bottom1, Dtype *top) |
| 31 | +{ |
| 32 | + |
| 33 | + |
| 34 | + for (int n = 0; n < num; ++n) |
| 35 | + { |
| 36 | + Dtype patch_data[kernel_size * kernel_size * bottomchannels]; |
| 37 | + |
| 38 | + for (int y = 0; y < topheight; ++y) |
| 39 | + for (int x = 0; x < topwidth; ++x) |
| 40 | + { |
| 41 | + int x1 = x * stride1 + max_displacement; |
| 42 | + int y1 = y * stride1 + max_displacement; |
| 43 | + |
| 44 | + // Load 3D patch into shared shared memory |
| 45 | + for (int j = 0; j < kernel_size; j++) // HEIGHT |
| 46 | + for (int i = 0; i < kernel_size; i++) // WIDTH |
| 47 | + { |
| 48 | + int ji_off = ((j * kernel_size) + i) * bottomchannels; |
| 49 | + for (int ch = 0; ch < bottomchannels; ch++) // CHANNELS |
| 50 | + { |
| 51 | + int idx1 = ((n * bottomheight + y1 + j) * bottomwidth + x1 + i) * bottomchannels + ch; |
| 52 | + int idxPatchData = ji_off + ch; |
| 53 | + patch_data[idxPatchData] = bottom0[idx1]; |
| 54 | + } |
| 55 | + } |
| 56 | + |
| 57 | + // Compute correlation |
| 58 | + for (int top_channel = 0; top_channel < topchannels; top_channel++) |
| 59 | + { |
| 60 | + double sum = 0; |
| 61 | + |
| 62 | + int s2o = (top_channel % neighborhood_grid_width - neighborhood_grid_radius) * stride2; |
| 63 | + int s2p = (top_channel / neighborhood_grid_width - neighborhood_grid_radius) * stride2; |
| 64 | + |
| 65 | + for (int j = 0; j < kernel_size; j++) // HEIGHT |
| 66 | + for (int i = 0; i < kernel_size; i++) // WIDTH |
| 67 | + { |
| 68 | + int ji_off = ((j * kernel_size) + i) * bottomchannels; |
| 69 | + for (int ch = 0; ch < bottomchannels; ch++) // CHANNELS |
| 70 | + { |
| 71 | + int x2 = x1 + s2o; |
| 72 | + int y2 = y1 + s2p; |
| 73 | + |
| 74 | + int idxPatchData = ji_off + ch; |
| 75 | + int idx2 = ((n * bottomheight + y2 + j) * bottomwidth + x2 + i) * bottomchannels + ch; |
| 76 | + |
| 77 | + sum += patch_data[idxPatchData] * bottom1[idx2]; |
| 78 | + } |
| 79 | + } |
| 80 | + |
| 81 | + |
| 82 | + |
| 83 | + |
| 84 | + const int sumelems = kernel_size * kernel_size * bottomchannels; |
| 85 | + const int index = ((top_channel * topheight + y) * topwidth) + x; |
| 86 | + |
| 87 | + // printf("%f\n", sum / (float) sumelems); |
| 88 | + |
| 89 | + top[index + n * topcount] = sum / (float) sumelems; |
| 90 | + } |
| 91 | + } |
| 92 | + } |
| 93 | + // Aggregate |
| 94 | +} |
| 95 | + |
| 96 | +template <typename Dtype> |
| 97 | +void CorrelateDataSubtract(int num, int item, int topwidth, int topheight, int topchannels, int topcount, |
| 98 | + int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int stride1, int stride2, |
| 99 | + int bottomwidth, int bottomheight, int bottomchannels, |
| 100 | + const Dtype *bottom0, const Dtype *bottom1, Dtype *top) |
| 101 | +{ |
| 102 | + for (int index = 0; index < topcount; index++) |
| 103 | + { |
| 104 | + int x = index % topwidth; //w-pos |
| 105 | + int y = (index / topwidth) % topheight; //h-pos |
| 106 | + int c = (index / topwidth / topheight) % topchannels; //channels |
| 107 | + |
| 108 | + // Offset of patch in image 2 |
| 109 | + int s2o = (c % neighborhood_grid_width - neighborhood_grid_radius) * stride2; |
| 110 | + int s2p = (c / neighborhood_grid_width - neighborhood_grid_radius) * stride2; |
| 111 | + |
| 112 | + // First (upper left) position of kernel center in current neighborhood in image 1 |
| 113 | + int x1 = x * stride1 + kernel_radius + max_displacement; |
| 114 | + int y1 = y * stride1 + kernel_radius + max_displacement; |
| 115 | + |
| 116 | + // Iterate through 3D patch |
| 117 | + Dtype sum = 0; |
| 118 | + for (int j = -kernel_radius; j <= kernel_radius; j++) |
| 119 | + { // HEIGHT |
| 120 | + for (int i = -kernel_radius; i <= kernel_radius; i++) |
| 121 | + { // WIDTH |
| 122 | + for (int l = 0; l < bottomchannels; l++) |
| 123 | + { // CHANNELS |
| 124 | + // Calculate position in image 2 |
| 125 | + int x2 = x1 + s2o; |
| 126 | + int y2 = y1 + s2p; |
| 127 | + |
| 128 | + // Indices in bottom data: (CH=l,W=x2,H=y2,N) |
| 129 | + int idx1 = ((item * bottomheight + y1 + j) * bottomwidth + x1 + i) * bottomchannels + l; |
| 130 | + int idx2 = ((item * bottomheight + y2 + j) * bottomwidth + x2 + i) * bottomchannels + l; |
| 131 | + |
| 132 | + // Do the correlation: |
| 133 | + sum += abs(bottom0[idx1] - bottom1[idx2]); |
| 134 | + } |
| 135 | + } |
| 136 | + } |
| 137 | + const int sumelems = (kernel_radius * 2 + 1)*(kernel_radius * 2 + 1) * bottomchannels; |
| 138 | + top[index + item * topcount] = sum / (float) sumelems; |
| 139 | + } |
| 140 | +} |
| 141 | + |
12 | 142 | template <typename Dtype> |
13 | 143 | void CorrelationLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, |
14 | 144 | const vector<Blob<Dtype>*>& top) { |
@@ -86,7 +216,51 @@ void CorrelationLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom, |
86 | 216 | template <typename Dtype> |
87 | 217 | void CorrelationLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, |
88 | 218 | const vector<Blob<Dtype>*>& top) { |
89 | | - NOT_IMPLEMENTED; |
| 219 | + |
| 220 | + CHECK_EQ(bottom.size(), 2); |
| 221 | + CHECK_EQ(top.size(), 1); |
| 222 | + |
| 223 | + const int bnum = bottom[0]->num(); |
| 224 | + const int bchannels = bottom[0]->channels(); |
| 225 | + const int bheight = bottom[0]->height(); |
| 226 | + const int bwidth = bottom[0]->width(); |
| 227 | + |
| 228 | + const int topcount = top_width_ * top_height_ * top_channels_; |
| 229 | + const int pheight = bheight + 2 * pad_size_; |
| 230 | + const int pwidth = bwidth + 2 * pad_size_; |
| 231 | + |
| 232 | + blob_rearrange_kernel2<Dtype>(rbot1_->mutable_cpu_data(), bottom[0]->cpu_data(), bnum, bchannels, bheight, bwidth, pheight, pwidth, pad_size_); |
| 233 | + blob_rearrange_kernel2<Dtype>(rbot2_->mutable_cpu_data(), bottom[1]->cpu_data(), bnum, bchannels, bheight, bwidth, pheight, pwidth, pad_size_); |
| 234 | + |
| 235 | + const int num = bnum; |
| 236 | + const int channels = bchannels; |
| 237 | + const int height = bheight + 2 * pad_size_; |
| 238 | + const int width = bwidth + 2 * pad_size_; |
| 239 | + |
| 240 | + if(corr_type_ == CorrelationParameter_CorrelationType_MULTIPLY) { |
| 241 | + // CorrelationLayer |
| 242 | + |
| 243 | + CorrelateData<Dtype>( |
| 244 | + num, top_width_, top_height_, top_channels_, topcount, |
| 245 | + max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_, kernel_size_, |
| 246 | + stride1_, stride2_, |
| 247 | + width, height, channels, |
| 248 | + rbot1_->cpu_data(), rbot2_->cpu_data(), top[0]->mutable_cpu_data() |
| 249 | + ); |
| 250 | + |
| 251 | + } else if(corr_type_ == CorrelationParameter_CorrelationType_SUBTRACT) { |
| 252 | + // CorrelationLayer |
| 253 | + for(int n = 0; n < num; n++) { |
| 254 | + |
| 255 | + CorrelateDataSubtract<Dtype>( |
| 256 | + num, n, top_width_, top_height_, top_channels_, topcount, |
| 257 | + max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_, |
| 258 | + stride1_, stride2_, |
| 259 | + width, height, channels, |
| 260 | + rbot1_->cpu_data(), rbot2_->cpu_data(), top[0]->mutable_cpu_data() |
| 261 | + ); |
| 262 | + } |
| 263 | + } |
90 | 264 | } |
91 | 265 |
|
92 | 266 | template <typename Dtype> |
|
0 commit comments