Skip to content

Commit c808660

Browse files
committed
rename the parameter shape as indices_shape for GatherNd
1 parent d56ffa3 commit c808660

File tree

3 files changed

+102
-102
lines changed

3 files changed

+102
-102
lines changed
Lines changed: 33 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,54 @@
11
#ifndef CAFFE_GATHER_ND_LAYER_HPP_
22
#define CAFFE_GATHER_ND_LAYER_HPP_
33

4-
#include <vector>
54
#include <string>
5+
#include <vector>
66

77
#include "caffe/blob.hpp"
88
#include "caffe/layer.hpp"
99
#include "caffe/proto/caffe.pb.h"
1010

11-
12-
1311
namespace caffe {
1412
/*
15-
* @brief Resize images to size using nearest neighbor interpolation. ////
16-
* Note: implementation of tf.gather_nd
17-
* https://www.tensorflow.org/api_docs/python/tf/gather_nd
18-
*/
13+
* @brief Resize images to size using nearest neighbor interpolation. ////
14+
* Note: implementation of tf.gather_nd
15+
* https://www.tensorflow.org/api_docs/python/tf/gather_nd
16+
*/
1917

20-
template <typename Dtype>
21-
class GatherNdLayer : public Layer<Dtype> {
18+
template <typename Dtype> class GatherNdLayer : public Layer<Dtype> {
2219
public:
23-
explicit GatherNdLayer(const LayerParameter& param)
24-
: Layer<Dtype>(param) {}
25-
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
26-
const vector<Blob<Dtype>*>& top);
27-
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
28-
const vector<Blob<Dtype>*>& top);
29-
30-
virtual inline const char* type() const { return "GatherNd"; }
31-
virtual inline int ExactNumBottomBlobs() const { return 1; }
32-
virtual inline int ExactNumTopBlobs() const { return 1; }
33-
34-
protected:
35-
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
36-
const vector<Blob<Dtype>*>& top);
20+
explicit GatherNdLayer(const LayerParameter &param) : Layer<Dtype>(param) {}
21+
virtual void LayerSetUp(const vector<Blob<Dtype> *> &bottom,
22+
const vector<Blob<Dtype> *> &top);
23+
virtual void Reshape(const vector<Blob<Dtype> *> &bottom,
24+
const vector<Blob<Dtype> *> &top);
25+
26+
virtual inline const char *type() const { return "GatherNd"; }
27+
virtual inline int ExactNumBottomBlobs() const { return 1; }
28+
virtual inline int ExactNumTopBlobs() const { return 1; }
29+
30+
protected:
31+
virtual void Forward_cpu(const vector<Blob<Dtype> *> &bottom,
32+
const vector<Blob<Dtype> *> &top);
3733
/// @brief Not implemented
38-
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
39-
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
40-
NOT_IMPLEMENTED;
34+
virtual void Backward_cpu(const vector<Blob<Dtype> *> &top,
35+
const vector<bool> &propagate_down,
36+
const vector<Blob<Dtype> *> &bottom) {
37+
NOT_IMPLEMENTED;
4138
}
42-
//virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
43-
// const vector<Blob<Dtype>*>& top) {}
44-
//virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
45-
// const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}
39+
// virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
40+
// const vector<Blob<Dtype>*>& top) {}
41+
// virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
42+
// const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom)
43+
// {}
4644

47-
int gather_nd_size_;
45+
int gather_nd_size_;
4846
int indices_dim_;
4947
int indices_N_;
50-
vector<int> indices_;
51-
vector<int> indices_shape_;
48+
vector<int> indices_;
49+
vector<int> indices_shape_;
5250
};
5351

54-
} // namespace caffe
55-
56-
#endif // CAFFE_GATHER_ND_LAYER_HPP_
52+
} // namespace caffe
5753

54+
#endif // CAFFE_GATHER_ND_LAYER_HPP_
Lines changed: 68 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,85 +1,88 @@
1-
#include <vector>
21
#include <algorithm>
32
#include <cmath>
3+
#include <vector>
44

55
#include "caffe/layers/gather_nd_layer.hpp"
6-
#include "caffe/util/math_functions.hpp"
7-
8-
6+
#include "caffe/util/math_functions.hpp"
97

108
namespace caffe {
119

1210
template <typename Dtype>
13-
void GatherNdLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
14-
const vector<Blob<Dtype>*>& top) {
15-
const GatherNdParameter& gather_nd_param = this->layer_param_.gather_nd_param();
16-
indices_.clear();
17-
std::copy(gather_nd_param.indices().begin(),
18-
gather_nd_param.indices().end(),
19-
std::back_inserter(indices_));
20-
indices_shape_.clear();
21-
std::copy(gather_nd_param.shape().begin(),
22-
gather_nd_param.shape().end(),
23-
std::back_inserter(indices_shape_));
11+
void GatherNdLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
12+
const vector<Blob<Dtype> *> &top) {
13+
const GatherNdParameter &gather_nd_param =
14+
this->layer_param_.gather_nd_param();
15+
indices_.clear();
16+
std::copy(gather_nd_param.indices().begin(), gather_nd_param.indices().end(),
17+
std::back_inserter(indices_));
18+
indices_shape_.clear();
19+
std::copy(gather_nd_param.indices_shape().begin(), gather_nd_param.indices_shape().end(),
20+
std::back_inserter(indices_shape_));
2421
}
2522

2623
template <typename Dtype>
27-
void GatherNdLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
28-
const vector<Blob<Dtype>*>& top) {
29-
const int num_axes = bottom[0]->num_axes();
30-
CHECK_GE(num_axes, 1) << "the dimension of input should be larger than or equal to 1";
31-
//const GatherNdParameter& gather_nd_param = this->layer_param_.gather_nd_param();
32-
indices_dim_ = indices_shape_.size();
33-
CHECK_GE(indices_dim_, 1) << "the dimension of indices should be larger than or equal to 1";
34-
int count = 1;
35-
for (int i = 0; i < indices_shape_.size(); ++i) {
36-
count *= indices_shape_[i];
37-
}
38-
CHECK_EQ(indices_.size(), count) << "the size and shape of indices do not match" ;
39-
vector<int> bottom_shape = bottom[0]->shape();
40-
vector<int> top_shape = bottom[0]->shape();
41-
indices_N_ = indices_shape_[indices_shape_.size()-1];
42-
CHECK_LE(indices_N_, num_axes) << "indices.shape[-1] must be <= params.rank, but saw indices.shape[-1]:"
43-
<< indices_N_ << ", and params.rank: " << num_axes;
44-
top_shape.resize(indices_dim_ - 1 + num_axes - indices_N_);
45-
gather_nd_size_ = bottom[0]->count(indices_N_);
24+
void GatherNdLayer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
25+
const vector<Blob<Dtype> *> &top) {
26+
const int num_axes = bottom[0]->num_axes();
27+
CHECK_GE(num_axes, 1)
28+
<< "the dimension of input should be larger than or equal to 1";
29+
// const GatherNdParameter& gather_nd_param =
30+
// this->layer_param_.gather_nd_param();
31+
indices_dim_ = indices_shape_.size();
32+
CHECK_GE(indices_dim_, 1)
33+
<< "the dimension of indices should be larger than or equal to 1";
34+
int count = 1;
35+
for (int i = 0; i < indices_shape_.size(); ++i) {
36+
count *= indices_shape_[i];
37+
}
38+
CHECK_EQ(indices_.size(), count)
39+
<< "the size and shape of indices do not match";
40+
vector<int> bottom_shape = bottom[0]->shape();
41+
vector<int> top_shape = bottom[0]->shape();
42+
indices_N_ = indices_shape_[indices_shape_.size() - 1];
43+
CHECK_LE(indices_N_, num_axes)
44+
<< "indices.shape[-1] must be <= params.rank, but saw indices.shape[-1]:"
45+
<< indices_N_ << ", and params.rank: " << num_axes;
46+
top_shape.resize(indices_dim_ - 1 + num_axes - indices_N_);
47+
gather_nd_size_ = bottom[0]->count(indices_N_);
4648

47-
// The result shape is
48-
// indices.shape[:-1] + params.shape[indices.shape[-1]:]
49-
for (int i = 0; i < indices_.size(); ++i) {
50-
CHECK_GE(indices_[i], 0) << "indices_ element with idx" << i << " is negative";
51-
}
52-
for (int i = 0; i < indices_dim_ - 1; ++i) {
53-
top_shape[i] = indices_shape_[i];
54-
}
55-
for (int i = 0; i < num_axes - indices_N_; ++i) {
56-
top_shape[i + indices_dim_ - 1] = bottom_shape[i + indices_N_];
57-
}
58-
top[0]->Reshape(top_shape);
49+
// The result shape is
50+
// indices.shape[:-1] + params.shape[indices.shape[-1]:]
51+
for (int i = 0; i < indices_.size(); ++i) {
52+
CHECK_GE(indices_[i], 0)
53+
<< "indices_ element with idx" << i << " is negative";
54+
}
55+
for (int i = 0; i < indices_dim_ - 1; ++i) {
56+
top_shape[i] = indices_shape_[i];
57+
}
58+
for (int i = 0; i < num_axes - indices_N_; ++i) {
59+
top_shape[i + indices_dim_ - 1] = bottom_shape[i + indices_N_];
60+
}
61+
top[0]->Reshape(top_shape);
5962
}
6063

6164
template <typename Dtype>
62-
void GatherNdLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
63-
const vector<Blob<Dtype>*>& top) {
64-
const Dtype* bottom_data = bottom[0]->cpu_data();
65-
Dtype* top_data = top[0]->mutable_cpu_data();
66-
vector<int> bottom_shape = bottom[0]->shape();
67-
for (int m = 0; m < indices_.size()/indices_N_; ++m) {
68-
const int top_offset = m * gather_nd_size_;
69-
int bottom_offset = 0;
70-
for (int n = 0; n < indices_N_; ++n) {
71-
int indices_value = indices_[m*indices_N_ + n];
72-
int params_idx = bottom_shape[n];
73-
CHECK_LT(indices_value, params_idx) << "indices value does not index into param dimension: " << n;
74-
bottom_offset += indices_[m*indices_N_ + n] * bottom[0]->count(n + 1);
75-
}
76-
caffe_copy(gather_nd_size_,
77-
bottom_data + bottom_offset, top_data + top_offset);
78-
}
65+
void GatherNdLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
66+
const vector<Blob<Dtype> *> &top) {
67+
const Dtype *bottom_data = bottom[0]->cpu_data();
68+
Dtype *top_data = top[0]->mutable_cpu_data();
69+
vector<int> bottom_shape = bottom[0]->shape();
70+
for (int m = 0; m < indices_.size() / indices_N_; ++m) {
71+
const int top_offset = m * gather_nd_size_;
72+
int bottom_offset = 0;
73+
for (int n = 0; n < indices_N_; ++n) {
74+
int indices_value = indices_[m * indices_N_ + n];
75+
int params_idx = bottom_shape[n];
76+
CHECK_LT(indices_value, params_idx)
77+
<< "indices value does not index into param dimension: " << n;
78+
bottom_offset += indices_[m * indices_N_ + n] * bottom[0]->count(n + 1);
79+
}
80+
caffe_copy(gather_nd_size_, bottom_data + bottom_offset,
81+
top_data + top_offset);
82+
}
7983
}
8084

8185
INSTANTIATE_CLASS(GatherNdLayer);
8286
REGISTER_LAYER_CLASS(GatherNd);
8387

84-
85-
} // namespace caffe
88+
} // namespace caffe

src/caffe/proto/caffe.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2946,7 +2946,7 @@ message NMSGatherParameter {
29462946

29472947
message GatherNdParameter {
29482948
repeated uint32 indices = 1;
2949-
repeated uint32 shape = 2;
2949+
repeated uint32 indices_shape = 2;
29502950
}
29512951

29522952
message Where4Parameter {

0 commit comments

Comments
 (0)