@@ -10,52 +10,77 @@ namespace caffe {
1010template <typename Dtype>
1111void GatherLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
1212 const vector<Blob<Dtype> *> &top) {
13- // Gather has 2 inputs: params, indices, 1 attribute:axis
1413 const GatherParameter &gather_param = this ->layer_param_ .gather_param ();
15- gather_axis_ = bottom[0 ]->CanonicalAxisIndex (gather_param.axis ());
16- indices_shape_ = bottom[1 ]->shape ();
17- CHECK_GE (bottom[0 ]->num_axes (), 1 )
18- << " the dimension of input should be larger than or equal to 1" ;
14+ indices_.clear ();
15+ std::copy (gather_param.indices ().begin (), gather_param.indices ().end (),
16+ std::back_inserter (indices_));
17+ indices_shape_.clear ();
18+ std::copy (gather_param.shape ().begin (), gather_param.shape ().end (),
19+ std::back_inserter (indices_shape_));
1920}
2021
2122template <typename Dtype>
2223void GatherLayer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
2324 const vector<Blob<Dtype> *> &top) {
25+ const int num_axes = bottom[0 ]->num_axes ();
26+ CHECK_GE (num_axes, 1 )
27+ << " the dimension of input should be larger than or equal to 1" ;
28+ const GatherParameter &gather_param = this ->layer_param_ .gather_param ();
29+ gather_axis_ = bottom[0 ]->CanonicalAxisIndex (gather_param.axis ());
30+ if (indices_shape_.size () == 1 && indices_shape_[0 ] == 0 ) {
31+ indices_dim_ = 0 ;
32+ CHECK_EQ (indices_.size (), 1 ) << " indices should be scalar!" ;
33+ } else {
34+ indices_dim_ = indices_shape_.size ();
35+ int count = 1 ;
36+ for (int i = 0 ; i < indices_shape_.size (); ++i) {
37+ count *= indices_shape_[i];
38+ }
39+ CHECK_EQ (indices_.size (), count)
40+ << " the size and shape of indices do not match" ;
41+ }
42+
2443 // Initialize with the first blob
2544 // The result shape is params.shape[-1:axis] + indices.shape +
2645 // params.shape[axis + 0:].
27- const int indices_dim_ = bottom[1 ]->num_axes ();
46+ vector< int > bottom_shape = bottom[0 ]->shape ();
2847 vector<int > top_shape = bottom[0 ]->shape ();
29- top_shape.erase (top_shape.begin () + gather_axis_);
30- top_shape.insert (top_shape.begin () + gather_axis_, indices_shape_.begin (),
31- indices_shape_.end ());
48+ top_shape.resize (bottom_shape.size () + indices_dim_ - 1 );
49+ num_gather_ = bottom[0 ]->count (0 , gather_axis_);
50+ gather_size_ = bottom[0 ]->count (gather_axis_ + 1 );
51+ for (int i = 0 ; i < indices_.size (); ++i) {
52+ CHECK_GE (indices_[i], 0 )
53+ << " indices_ element with idx" << i << " is negative" ;
54+ CHECK_LT (indices_[i], bottom[0 ]->shape (gather_axis_))
55+ << " indices_ element with idx" << i << " is out of range "
56+ << bottom[0 ]->shape (gather_axis_);
57+ }
58+ for (int i = 0 ; i < gather_axis_; ++i) {
59+ top_shape[i] = bottom_shape[i];
60+ }
61+ for (int i = 0 ; i < indices_dim_; ++i) {
62+ top_shape[i + gather_axis_] = indices_shape_[i];
63+ }
64+ for (int i = gather_axis_ + 1 ; i < num_axes; ++i) {
65+ top_shape[i + indices_dim_ - 1 ] = bottom_shape[i];
66+ }
3267 top[0 ]->Reshape (top_shape);
3368}
3469
3570template <typename Dtype>
3671void GatherLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
3772 const vector<Blob<Dtype> *> &top) {
3873 vector<int > bottom_shape = bottom[0 ]->shape ();
39- const int num_gather_ = bottom[0 ]->count (0 , gather_axis_);
40- const int gather_size_ = bottom[0 ]->count (gather_axis_ + 1 );
74+ // const Dtype* params = bottom[0]->cpu_data();
4175 const Dtype *bottom_data = bottom[0 ]->cpu_data ();
42- const Dtype *indices_ = bottom[1 ]->cpu_data ();
43- // check indices_
44- for (int i = 0 ; i < bottom[1 ]->count (); ++i) {
45- CHECK_GE (indices_[i], 0 )
46- << " indices_ element with idx" << i << " is negative" ;
47- CHECK_LT (indices_[i], bottom[0 ]->shape (gather_axis_))
48- << " indices_ element with idx" << i << " is out of range "
49- << bottom[0 ]->shape (gather_axis_);
50- }
5176 Dtype *top_data = top[0 ]->mutable_cpu_data ();
5277 const int bottom_gather_axis = bottom[0 ]->shape (gather_axis_);
5378 int num = 0 ;
5479 for (int m = 0 ; m < num_gather_; ++m) {
55- for (int n = 0 ; n < bottom[ 1 ]-> count (); ++n) {
80+ for (int n = 0 ; n < indices_. size (); ++n) {
5681 const int top_offset = num * gather_size_;
5782 const int bottom_offset =
58- (m * bottom_gather_axis + ( int ) indices_[n]) * gather_size_;
83+ (m * bottom_gather_axis + indices_[n]) * gather_size_;
5984 caffe_copy (gather_size_, bottom_data + bottom_offset,
6085 top_data + top_offset);
6186 num += 1 ;
0 commit comments