Skip to content

Commit da5e794

Browse files
committed
change gather layer: 1 input -> 2 inputs
1 parent ca7e58b commit da5e794

File tree

3 files changed

+72
-104
lines changed

3 files changed

+72
-104
lines changed
Lines changed: 32 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,52 @@
11
#ifndef CAFFE_GATHER_LAYER_HPP_
22
#define CAFFE_GATHER_LAYER_HPP_
33

4-
#include <vector>
54
#include <string>
5+
#include <vector>
66

77
#include "caffe/blob.hpp"
88
#include "caffe/layer.hpp"
99
#include "caffe/proto/caffe.pb.h"
1010

1111
namespace caffe {
1212
/*
13-
* @brief Resize images to size using nearest neighbor interpolation. ////
14-
* Note: implementation of tf.gather
15-
* https://www.tensorflow.org/api_docs/python/tf/gather
16-
*/
13+
* @brief Resize images to size using nearest neighbor interpolation. ////
14+
* Note: implementation of tf.gather
15+
* https://www.tensorflow.org/api_docs/python/tf/gather
16+
*/
1717

18-
template <typename Dtype>
19-
class GatherLayer : public Layer<Dtype> {
18+
template <typename Dtype> class GatherLayer : public Layer<Dtype> {
2019
public:
21-
explicit GatherLayer(const LayerParameter& param)
22-
: Layer<Dtype>(param) {}
23-
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
24-
const vector<Blob<Dtype>*>& top);
25-
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
26-
const vector<Blob<Dtype>*>& top);
27-
28-
virtual inline const char* type() const { return "Gather"; }
29-
virtual inline int ExactNumBottomBlobs() const { return 1; }
30-
virtual inline int ExactNumTopBlobs() const { return 1; }
31-
32-
protected:
33-
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
34-
const vector<Blob<Dtype>*>& top);
20+
explicit GatherLayer(const LayerParameter &param) : Layer<Dtype>(param) {}
21+
virtual void LayerSetUp(const vector<Blob<Dtype> *> &bottom,
22+
const vector<Blob<Dtype> *> &top);
23+
virtual void Reshape(const vector<Blob<Dtype> *> &bottom,
24+
const vector<Blob<Dtype> *> &top);
25+
26+
virtual inline const char *type() const { return "Gather"; }
27+
virtual inline int ExactNumBottomBlobs() const { return 2; }
28+
virtual inline int ExactNumTopBlobs() const { return 1; }
29+
30+
protected:
31+
virtual void Forward_cpu(const vector<Blob<Dtype> *> &bottom,
32+
const vector<Blob<Dtype> *> &top);
3533
/// @brief Not implemented
36-
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
37-
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
38-
NOT_IMPLEMENTED;
34+
virtual void Backward_cpu(const vector<Blob<Dtype> *> &top,
35+
const vector<bool> &propagate_down,
36+
const vector<Blob<Dtype> *> &bottom) {
37+
NOT_IMPLEMENTED;
3938
}
40-
//virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
41-
// const vector<Blob<Dtype>*>& top) {}
42-
//virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
43-
// const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}
44-
45-
//int count_;
46-
int num_gather_;
47-
int gather_size_;
48-
int gather_axis_;
39+
// virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
40+
// const vector<Blob<Dtype>*>& top) {}
41+
// virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
42+
// const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom)
43+
// {}
44+
45+
int gather_axis_;
4946
int indices_dim_;
50-
vector<int> indices_;
5147
vector<int> indices_shape_;
5248
};
5349

54-
} // namespace caffe
55-
56-
#endif // CAFFE_GATHER_LAYER_HPP_
50+
} // namespace caffe
5751

52+
#endif // CAFFE_GATHER_LAYER_HPP_

src/caffe/layers/gather_layer.cpp

Lines changed: 39 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,69 @@
1-
#include <vector>
21
#include <algorithm>
32
#include <cmath>
3+
#include <vector>
44

55
#include "caffe/layers/gather_layer.hpp"
6-
#include "caffe/util/math_functions.hpp"
6+
#include "caffe/util/math_functions.hpp"
77

88
namespace caffe {
99

1010
template <typename Dtype>
11-
void GatherLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
12-
const vector<Blob<Dtype>*>& top) {
13-
const GatherParameter& gather_param = this->layer_param_.gather_param();
14-
indices_.clear();
15-
std::copy(gather_param.indices().begin(),
16-
gather_param.indices().end(),
17-
std::back_inserter(indices_));
18-
indices_shape_.clear();
19-
std::copy(gather_param.shape().begin(),
20-
gather_param.shape().end(),
21-
std::back_inserter(indices_shape_));
11+
void GatherLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
12+
const vector<Blob<Dtype> *> &top) {
13+
// Gather has 2 inputs: params, indices, 1 attribute:axis
14+
const GatherParameter &gather_param = this->layer_param_.gather_param();
15+
gather_axis_ = bottom[0]->CanonicalAxisIndex(gather_param.axis());
16+
indices_shape_ = bottom[1]->shape();
17+
CHECK_GE(bottom[0]->num_axes(), 1)
18+
<< "the dimension of input should be larger than or equal to 1";
2219
}
2320

2421
template <typename Dtype>
25-
void GatherLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
26-
const vector<Blob<Dtype>*>& top) {
27-
const int num_axes = bottom[0]->num_axes();
28-
CHECK_GE(num_axes, 1) << "the dimension of input should be larger than or equal to 1";
29-
const GatherParameter& gather_param = this->layer_param_.gather_param();
30-
gather_axis_ = bottom[0]->CanonicalAxisIndex(gather_param.axis());
31-
if (indices_shape_.size() == 1 && indices_shape_[0] == 0) {
32-
indices_dim_ = 0;
33-
CHECK_EQ(indices_.size(), 1) << "indices should be scalar!";
34-
}
35-
else {
36-
indices_dim_ = indices_shape_.size();
37-
int count = 1;
38-
for (int i = 0; i < indices_shape_.size(); ++i) {
39-
count *= indices_shape_[i];
40-
}
41-
CHECK_EQ(indices_.size(), count) << "the size and shape of indices do not match";
42-
}
43-
44-
// Initialize with the first blob
22+
void GatherLayer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
23+
const vector<Blob<Dtype> *> &top) {
24+
// Initialize with the first blob
4525
// The result shape is params.shape[-1:axis] + indices.shape +
4626
// params.shape[axis + 0:].
47-
vector<int> bottom_shape = bottom[0]->shape();
27+
const int indices_dim_ = bottom[1]->num_axes();
4828
vector<int> top_shape = bottom[0]->shape();
49-
top_shape.resize(bottom_shape.size() + indices_dim_ - 1);
50-
num_gather_ = bottom[0]->count(0, gather_axis_);
51-
gather_size_ = bottom[0]->count(gather_axis_ + 1);
52-
for (int i = 0; i < indices_.size(); ++i) {
53-
CHECK_GE(indices_[i], 0) << "indices_ element with idx" << i << " is negative";
54-
CHECK_LT(indices_[i], bottom[0]->shape(gather_axis_))
55-
<< "indices_ element with idx" << i << " is out of range "
56-
<< bottom[0]->shape(gather_axis_);
57-
}
58-
for (int i = 0; i < gather_axis_; ++i) {
59-
top_shape[i] = bottom_shape[i];
60-
}
61-
for (int i = 0; i < indices_dim_; ++i) {
62-
top_shape[i + gather_axis_] = indices_shape_[i];
63-
}
64-
for (int i = gather_axis_ + 1; i < num_axes; ++i) {
65-
top_shape[i + indices_dim_ - 1] = bottom_shape[i];
66-
}
29+
top_shape.erase(top_shape.begin() + gather_axis_);
30+
top_shape.insert(top_shape.begin() + gather_axis_, indices_shape_.begin(),
31+
indices_shape_.end());
6732
top[0]->Reshape(top_shape);
6833
}
6934

7035
template <typename Dtype>
71-
void GatherLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
72-
const vector<Blob<Dtype>*>& top) {
36+
void GatherLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
37+
const vector<Blob<Dtype> *> &top) {
7338
vector<int> bottom_shape = bottom[0]->shape();
74-
//const Dtype* params = bottom[0]->cpu_data();
75-
const Dtype* bottom_data = bottom[0]->cpu_data();
76-
Dtype* top_data = top[0]->mutable_cpu_data();
39+
const int num_gather_ = bottom[0]->count(0, gather_axis_);
40+
const int gather_size_ = bottom[0]->count(gather_axis_ + 1);
41+
const Dtype *bottom_data = bottom[0]->cpu_data();
42+
const Dtype *indices_ = bottom[1]->cpu_data();
43+
// check indices_
44+
for (int i = 0; i < bottom[1]->count(); ++i) {
45+
CHECK_GE(indices_[i], 0)
46+
<< "indices_ element with idx" << i << " is negative";
47+
CHECK_LT(indices_[i], bottom[0]->shape(gather_axis_))
48+
<< "indices_ element with idx" << i << " is out of range "
49+
<< bottom[0]->shape(gather_axis_);
50+
}
51+
Dtype *top_data = top[0]->mutable_cpu_data();
7752
const int bottom_gather_axis = bottom[0]->shape(gather_axis_);
7853
int num = 0;
7954
for (int m = 0; m < num_gather_; ++m) {
80-
for (int n = 0; n < indices_.size(); ++n) {
81-
const int top_offset = num * gather_size_;
55+
for (int n = 0; n < bottom[1]->count(); ++n) {
56+
const int top_offset = num * gather_size_;
8257
const int bottom_offset =
83-
(m * bottom_gather_axis + indices_[n]) * gather_size_;
84-
caffe_copy(gather_size_,
85-
bottom_data + bottom_offset, top_data + top_offset);
58+
(m * bottom_gather_axis + (int)indices_[n]) * gather_size_;
59+
caffe_copy(gather_size_, bottom_data + bottom_offset,
60+
top_data + top_offset);
8661
num += 1;
87-
}
62+
}
8863
}
8964
}
9065

9166
INSTANTIATE_CLASS(GatherLayer);
9267
REGISTER_LAYER_CLASS(Gather);
9368

94-
} // namespace caffe
69+
} // namespace caffe

src/caffe/proto/caffe.proto

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2947,9 +2947,7 @@ message ResizeNearestNeighborParameter {
29472947
}
29482948

29492949
message GatherParameter{
2950-
repeated uint32 indices = 1;
2951-
repeated uint32 shape = 2;
2952-
optional int32 axis = 3 [default = 0];
2950+
optional int32 axis = 1 [default = 0];
29532951
}
29542952

29552953
message TopkGatherParameter {

0 commit comments

Comments
 (0)