|
1 | | -#include <vector> |
2 | 1 | #include <algorithm> |
3 | 2 | #include <cmath> |
| 3 | +#include <vector> |
4 | 4 |
|
5 | 5 | #include "caffe/layers/gather_layer.hpp" |
6 | | -#include "caffe/util/math_functions.hpp" |
| 6 | +#include "caffe/util/math_functions.hpp" |
7 | 7 |
|
8 | 8 | namespace caffe { |
9 | 9 |
|
10 | 10 | template <typename Dtype> |
11 | | -void GatherLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, |
12 | | - const vector<Blob<Dtype>*>& top) { |
13 | | - const GatherParameter& gather_param = this->layer_param_.gather_param(); |
14 | | - indices_.clear(); |
15 | | - std::copy(gather_param.indices().begin(), |
16 | | - gather_param.indices().end(), |
17 | | - std::back_inserter(indices_)); |
18 | | - indices_shape_.clear(); |
19 | | - std::copy(gather_param.shape().begin(), |
20 | | - gather_param.shape().end(), |
21 | | - std::back_inserter(indices_shape_)); |
| 11 | +void GatherLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom, |
| 12 | + const vector<Blob<Dtype> *> &top) { |
| 13 | + // Gather takes two inputs (params, indices) and one attribute (axis). |
| 14 | + const GatherParameter &gather_param = this->layer_param_.gather_param(); |
| 15 | + gather_axis_ = bottom[0]->CanonicalAxisIndex(gather_param.axis()); |
| 16 | + indices_shape_ = bottom[1]->shape(); |
| 17 | + CHECK_GE(bottom[0]->num_axes(), 1) |
| 18 | + << "the dimension of input should be larger than or equal to 1"; |
22 | 19 | } |
23 | 20 |
|
24 | 21 | template <typename Dtype> |
25 | | -void GatherLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom, |
26 | | - const vector<Blob<Dtype>*>& top) { |
27 | | - const int num_axes = bottom[0]->num_axes(); |
28 | | - CHECK_GE(num_axes, 1) << "the dimension of input should be larger than or equal to 1"; |
29 | | - const GatherParameter& gather_param = this->layer_param_.gather_param(); |
30 | | - gather_axis_ = bottom[0]->CanonicalAxisIndex(gather_param.axis()); |
31 | | - if (indices_shape_.size() == 1 && indices_shape_[0] == 0) { |
32 | | - indices_dim_ = 0; |
33 | | - CHECK_EQ(indices_.size(), 1) << "indices should be scalar!"; |
34 | | - } |
35 | | - else { |
36 | | - indices_dim_ = indices_shape_.size(); |
37 | | - int count = 1; |
38 | | - for (int i = 0; i < indices_shape_.size(); ++i) { |
39 | | - count *= indices_shape_[i]; |
40 | | - } |
41 | | - CHECK_EQ(indices_.size(), count) << "the size and shape of indices do not match"; |
42 | | - } |
43 | | - |
44 | | - // Initialize with the first blob |
| 22 | +void GatherLayer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom, |
| 23 | + const vector<Blob<Dtype> *> &top) { |
| 24 | + // Initialize with the first blob |
45 | 25 | // The result shape is params.shape[:axis] + indices.shape + |
46 | 26 | // params.shape[axis + 1:]. |
47 | | - vector<int> bottom_shape = bottom[0]->shape(); |
| 27 | + const int indices_dim_ = bottom[1]->num_axes(); |
48 | 28 | vector<int> top_shape = bottom[0]->shape(); |
49 | | - top_shape.resize(bottom_shape.size() + indices_dim_ - 1); |
50 | | - num_gather_ = bottom[0]->count(0, gather_axis_); |
51 | | - gather_size_ = bottom[0]->count(gather_axis_ + 1); |
52 | | - for (int i = 0; i < indices_.size(); ++i) { |
53 | | - CHECK_GE(indices_[i], 0) << "indices_ element with idx" << i << " is negative"; |
54 | | - CHECK_LT(indices_[i], bottom[0]->shape(gather_axis_)) |
55 | | - << "indices_ element with idx" << i << " is out of range " |
56 | | - << bottom[0]->shape(gather_axis_); |
57 | | - } |
58 | | - for (int i = 0; i < gather_axis_; ++i) { |
59 | | - top_shape[i] = bottom_shape[i]; |
60 | | - } |
61 | | - for (int i = 0; i < indices_dim_; ++i) { |
62 | | - top_shape[i + gather_axis_] = indices_shape_[i]; |
63 | | - } |
64 | | - for (int i = gather_axis_ + 1; i < num_axes; ++i) { |
65 | | - top_shape[i + indices_dim_ - 1] = bottom_shape[i]; |
66 | | - } |
| 29 | + top_shape.erase(top_shape.begin() + gather_axis_); |
| 30 | + top_shape.insert(top_shape.begin() + gather_axis_, indices_shape_.begin(), |
| 31 | + indices_shape_.end()); |
67 | 32 | top[0]->Reshape(top_shape); |
68 | 33 | } |
69 | 34 |
|
70 | 35 | template <typename Dtype> |
71 | | -void GatherLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, |
72 | | - const vector<Blob<Dtype>*>& top) { |
| 36 | +void GatherLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom, |
| 37 | + const vector<Blob<Dtype> *> &top) { |
73 | 38 | vector<int> bottom_shape = bottom[0]->shape(); |
74 | | - //const Dtype* params = bottom[0]->cpu_data(); |
75 | | - const Dtype* bottom_data = bottom[0]->cpu_data(); |
76 | | - Dtype* top_data = top[0]->mutable_cpu_data(); |
| 39 | + const int num_gather_ = bottom[0]->count(0, gather_axis_); |
| 40 | + const int gather_size_ = bottom[0]->count(gather_axis_ + 1); |
| 41 | + const Dtype *bottom_data = bottom[0]->cpu_data(); |
| 42 | + const Dtype *indices_ = bottom[1]->cpu_data(); |
| 43 | + // Validate indices_: each entry must lie in [0, bottom[0]->shape(gather_axis_)). |
| 44 | + for (int i = 0; i < bottom[1]->count(); ++i) { |
| 45 | + CHECK_GE(indices_[i], 0) |
| 46 | + << "indices_ element with idx" << i << " is negative"; |
| 47 | + CHECK_LT(indices_[i], bottom[0]->shape(gather_axis_)) |
| 48 | + << "indices_ element with idx" << i << " is out of range " |
| 49 | + << bottom[0]->shape(gather_axis_); |
| 50 | + } |
| 51 | + Dtype *top_data = top[0]->mutable_cpu_data(); |
77 | 52 | const int bottom_gather_axis = bottom[0]->shape(gather_axis_); |
78 | 53 | int num = 0; |
79 | 54 | for (int m = 0; m < num_gather_; ++m) { |
80 | | - for (int n = 0; n < indices_.size(); ++n) { |
81 | | - const int top_offset = num * gather_size_; |
| 55 | + for (int n = 0; n < bottom[1]->count(); ++n) { |
| 56 | + const int top_offset = num * gather_size_; |
82 | 57 | const int bottom_offset = |
83 | | - (m * bottom_gather_axis + indices_[n]) * gather_size_; |
84 | | - caffe_copy(gather_size_, |
85 | | - bottom_data + bottom_offset, top_data + top_offset); |
| 58 | + (m * bottom_gather_axis + (int)indices_[n]) * gather_size_; |
| 59 | + caffe_copy(gather_size_, bottom_data + bottom_offset, |
| 60 | + top_data + top_offset); |
86 | 61 | num += 1; |
87 | | - } |
| 62 | + } |
88 | 63 | } |
89 | 64 | } |
90 | 65 |
|
91 | 66 | INSTANTIATE_CLASS(GatherLayer); |
92 | 67 | REGISTER_LAYER_CLASS(Gather); |
93 | 68 |
|
94 | | -} // namespace caffe |
| 69 | +} // namespace caffe |
0 commit comments