Skip to content

Commit 1e93e85

Browse files
committed
keep old Gather and rename changed Gather as GatherV2
1 parent da5e794 commit 1e93e85

File tree

5 files changed

+183
-24
lines changed

5 files changed

+183
-24
lines changed

include/caffe/layers/gather_layer.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ template <typename Dtype> class GatherLayer : public Layer<Dtype> {
2424
const vector<Blob<Dtype> *> &top);
2525

2626
virtual inline const char *type() const { return "Gather"; }
27-
virtual inline int ExactNumBottomBlobs() const { return 2; }
27+
virtual inline int ExactNumBottomBlobs() const { return 1; }
2828
virtual inline int ExactNumTopBlobs() const { return 1; }
2929

3030
protected:
@@ -42,8 +42,12 @@ template <typename Dtype> class GatherLayer : public Layer<Dtype> {
4242
// const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom)
4343
// {}
4444

45+
// int count_;
46+
int num_gather_;
47+
int gather_size_;
4548
int gather_axis_;
4649
int indices_dim_;
50+
vector<int> indices_;
4751
vector<int> indices_shape_;
4852
};
4953

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#ifndef CAFFE_GATHER_V2_LAYER_HPP_
2+
#define CAFFE_GATHER_V2_LAYER_HPP_
3+
4+
#include <string>
5+
#include <vector>
6+
7+
#include "caffe/blob.hpp"
8+
#include "caffe/layer.hpp"
9+
#include "caffe/proto/caffe.pb.h"
10+
11+
namespace caffe {
12+
/*
13+
* @brief Gather slices from params along the given axis according to indices.
14+
* Note: another implementation of tf.gather
15+
* https://www.tensorflow.org/api_docs/python/tf/gather
16+
* In GatherV2, params and indices are inputs, axis is attribute
17+
*/
18+
19+
template <typename Dtype> class GatherV2Layer : public Layer<Dtype> {
20+
public:
21+
explicit GatherV2Layer(const LayerParameter &param) : Layer<Dtype>(param) {}
22+
virtual void LayerSetUp(const vector<Blob<Dtype> *> &bottom,
23+
const vector<Blob<Dtype> *> &top);
24+
virtual void Reshape(const vector<Blob<Dtype> *> &bottom,
25+
const vector<Blob<Dtype> *> &top);
26+
27+
virtual inline const char *type() const { return "GatherV2"; }
28+
virtual inline int ExactNumBottomBlobs() const { return 2; }
29+
virtual inline int ExactNumTopBlobs() const { return 1; }
30+
31+
protected:
32+
virtual void Forward_cpu(const vector<Blob<Dtype> *> &bottom,
33+
const vector<Blob<Dtype> *> &top);
34+
/// @brief Not implemented
35+
virtual void Backward_cpu(const vector<Blob<Dtype> *> &top,
36+
const vector<bool> &propagate_down,
37+
const vector<Blob<Dtype> *> &bottom) {
38+
NOT_IMPLEMENTED;
39+
}
40+
// virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
41+
// const vector<Blob<Dtype>*>& top) {}
42+
// virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
43+
// const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom)
44+
// {}
45+
46+
int gather_axis_;
47+
int indices_dim_;
48+
vector<int> indices_shape_;
49+
};
50+
51+
} // namespace caffe
52+
53+
#endif // CAFFE_GATHER_V2_LAYER_HPP_

src/caffe/layers/gather_layer.cpp

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,52 +10,77 @@ namespace caffe {
1010
template <typename Dtype>
1111
void GatherLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
1212
const vector<Blob<Dtype> *> &top) {
13-
// Gather has 2 inputs: params, indices, 1 attribute:axis
1413
const GatherParameter &gather_param = this->layer_param_.gather_param();
15-
gather_axis_ = bottom[0]->CanonicalAxisIndex(gather_param.axis());
16-
indices_shape_ = bottom[1]->shape();
17-
CHECK_GE(bottom[0]->num_axes(), 1)
18-
<< "the dimension of input should be larger than or equal to 1";
14+
indices_.clear();
15+
std::copy(gather_param.indices().begin(), gather_param.indices().end(),
16+
std::back_inserter(indices_));
17+
indices_shape_.clear();
18+
std::copy(gather_param.shape().begin(), gather_param.shape().end(),
19+
std::back_inserter(indices_shape_));
1920
}
2021

2122
template <typename Dtype>
2223
void GatherLayer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
2324
const vector<Blob<Dtype> *> &top) {
25+
const int num_axes = bottom[0]->num_axes();
26+
CHECK_GE(num_axes, 1)
27+
<< "the dimension of input should be larger than or equal to 1";
28+
const GatherParameter &gather_param = this->layer_param_.gather_param();
29+
gather_axis_ = bottom[0]->CanonicalAxisIndex(gather_param.axis());
30+
if (indices_shape_.size() == 1 && indices_shape_[0] == 0) {
31+
indices_dim_ = 0;
32+
CHECK_EQ(indices_.size(), 1) << "indices should be scalar!";
33+
} else {
34+
indices_dim_ = indices_shape_.size();
35+
int count = 1;
36+
for (int i = 0; i < indices_shape_.size(); ++i) {
37+
count *= indices_shape_[i];
38+
}
39+
CHECK_EQ(indices_.size(), count)
40+
<< "the size and shape of indices do not match";
41+
}
42+
2443
// Initialize with the first blob
2544
// The result shape is params.shape[:axis] + indices.shape +
2645
// params.shape[axis + 1:].
27-
const int indices_dim_ = bottom[1]->num_axes();
46+
vector<int> bottom_shape = bottom[0]->shape();
2847
vector<int> top_shape = bottom[0]->shape();
29-
top_shape.erase(top_shape.begin() + gather_axis_);
30-
top_shape.insert(top_shape.begin() + gather_axis_, indices_shape_.begin(),
31-
indices_shape_.end());
48+
top_shape.resize(bottom_shape.size() + indices_dim_ - 1);
49+
num_gather_ = bottom[0]->count(0, gather_axis_);
50+
gather_size_ = bottom[0]->count(gather_axis_ + 1);
51+
for (int i = 0; i < indices_.size(); ++i) {
52+
CHECK_GE(indices_[i], 0)
53+
<< "indices_ element with idx" << i << " is negative";
54+
CHECK_LT(indices_[i], bottom[0]->shape(gather_axis_))
55+
<< "indices_ element with idx" << i << " is out of range "
56+
<< bottom[0]->shape(gather_axis_);
57+
}
58+
for (int i = 0; i < gather_axis_; ++i) {
59+
top_shape[i] = bottom_shape[i];
60+
}
61+
for (int i = 0; i < indices_dim_; ++i) {
62+
top_shape[i + gather_axis_] = indices_shape_[i];
63+
}
64+
for (int i = gather_axis_ + 1; i < num_axes; ++i) {
65+
top_shape[i + indices_dim_ - 1] = bottom_shape[i];
66+
}
3267
top[0]->Reshape(top_shape);
3368
}
3469

3570
template <typename Dtype>
3671
void GatherLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
3772
const vector<Blob<Dtype> *> &top) {
3873
vector<int> bottom_shape = bottom[0]->shape();
39-
const int num_gather_ = bottom[0]->count(0, gather_axis_);
40-
const int gather_size_ = bottom[0]->count(gather_axis_ + 1);
74+
// const Dtype* params = bottom[0]->cpu_data();
4175
const Dtype *bottom_data = bottom[0]->cpu_data();
42-
const Dtype *indices_ = bottom[1]->cpu_data();
43-
// check indices_
44-
for (int i = 0; i < bottom[1]->count(); ++i) {
45-
CHECK_GE(indices_[i], 0)
46-
<< "indices_ element with idx" << i << " is negative";
47-
CHECK_LT(indices_[i], bottom[0]->shape(gather_axis_))
48-
<< "indices_ element with idx" << i << " is out of range "
49-
<< bottom[0]->shape(gather_axis_);
50-
}
5176
Dtype *top_data = top[0]->mutable_cpu_data();
5277
const int bottom_gather_axis = bottom[0]->shape(gather_axis_);
5378
int num = 0;
5479
for (int m = 0; m < num_gather_; ++m) {
55-
for (int n = 0; n < bottom[1]->count(); ++n) {
80+
for (int n = 0; n < indices_.size(); ++n) {
5681
const int top_offset = num * gather_size_;
5782
const int bottom_offset =
58-
(m * bottom_gather_axis + (int)indices_[n]) * gather_size_;
83+
(m * bottom_gather_axis + indices_[n]) * gather_size_;
5984
caffe_copy(gather_size_, bottom_data + bottom_offset,
6085
top_data + top_offset);
6186
num += 1;
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#include <algorithm>
2+
#include <cmath>
3+
#include <vector>
4+
5+
#include "caffe/layers/gather_v2_layer.hpp"
6+
#include "caffe/util/math_functions.hpp"
7+
8+
namespace caffe {
9+
10+
template <typename Dtype>
11+
void GatherV2Layer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
12+
const vector<Blob<Dtype> *> &top) {
13+
// GatherV2 has 2 inputs: params, indices, 1 attribute:axis
14+
const GatherV2Parameter &gather_v2_param =
15+
this->layer_param_.gather_v2_param();
16+
gather_axis_ = bottom[0]->CanonicalAxisIndex(gather_v2_param.axis());
17+
indices_shape_ = bottom[1]->shape();
18+
CHECK_GE(bottom[0]->num_axes(), 1)
19+
<< "the dimension of input should be larger than or equal to 1";
20+
}
21+
22+
template <typename Dtype>
23+
void GatherV2Layer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
24+
const vector<Blob<Dtype> *> &top) {
25+
// Initialize with the first blob
26+
// The result shape is params.shape[:axis] + indices.shape +
27+
// params.shape[axis + 1:].
28+
const int indices_dim_ = bottom[1]->num_axes();
29+
vector<int> top_shape = bottom[0]->shape();
30+
top_shape.erase(top_shape.begin() + gather_axis_);
31+
top_shape.insert(top_shape.begin() + gather_axis_, indices_shape_.begin(),
32+
indices_shape_.end());
33+
top[0]->Reshape(top_shape);
34+
}
35+
36+
template <typename Dtype>
37+
void GatherV2Layer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
38+
const vector<Blob<Dtype> *> &top) {
39+
vector<int> bottom_shape = bottom[0]->shape();
40+
const int num_gather_ = bottom[0]->count(0, gather_axis_);
41+
const int gather_size_ = bottom[0]->count(gather_axis_ + 1);
42+
const Dtype *bottom_data = bottom[0]->cpu_data();
43+
const Dtype *indices_ = bottom[1]->cpu_data();
44+
// check indices_
45+
for (int i = 0; i < bottom[1]->count(); ++i) {
46+
CHECK_GE(indices_[i], 0)
47+
<< "indices_ element with idx" << i << " is negative";
48+
CHECK_LT(indices_[i], bottom[0]->shape(gather_axis_))
49+
<< "indices_ element with idx" << i << " is out of range "
50+
<< bottom[0]->shape(gather_axis_);
51+
}
52+
Dtype *top_data = top[0]->mutable_cpu_data();
53+
const int bottom_gather_axis = bottom[0]->shape(gather_axis_);
54+
int num = 0;
55+
for (int m = 0; m < num_gather_; ++m) {
56+
for (int n = 0; n < bottom[1]->count(); ++n) {
57+
const int top_offset = num * gather_size_;
58+
const int bottom_offset =
59+
(m * bottom_gather_axis + (int)indices_[n]) * gather_size_;
60+
caffe_copy(gather_size_, bottom_data + bottom_offset,
61+
top_data + top_offset);
62+
num += 1;
63+
}
64+
}
65+
}
66+
67+
INSTANTIATE_CLASS(GatherV2Layer);
68+
REGISTER_LAYER_CLASS(GatherV2);
69+
70+
} // namespace caffe

src/caffe/proto/caffe.proto

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,7 @@ message LayerParameter {
664664
optional ReverseParameter reverse_param = 261;
665665
optional LpNormalizationParameter lp_normalization_param = 262;
666666
optional MatMulParameter matmul_param = 263;
667+
optional GatherV2Parameter gather_v2_param = 264;
667668
}
668669

669670
message AccumParameter {
@@ -2947,7 +2948,9 @@ message ResizeNearestNeighborParameter {
29472948
}
29482949

29492950
message GatherParameter{
2950-
optional int32 axis = 1 [default = 0];
2951+
repeated uint32 indices = 1;
2952+
repeated uint32 shape = 2;
2953+
optional int32 axis = 3 [default = 0];
29512954
}
29522955

29532956
message TopkGatherParameter {
@@ -3221,3 +3224,7 @@ message MatMulParameter {
32213224
optional bool transpose_a = 1[default = false];
32223225
optional bool transpose_b = 2[default = false];
32233226
}
3227+
3228+
message GatherV2Parameter {
3229+
optional int32 axis = 1[default = 0];
3230+
}

0 commit comments

Comments
 (0)