
Commit d9bc0bc

refine NotEqual to remove Parameter layer
1 parent: 7a9eb7d

3 files changed (+39 −21 lines)


include/caffe/layers/not_equal_layer.hpp

Lines changed: 5 additions & 0 deletions

```diff
@@ -16,6 +16,8 @@ template <typename Dtype> class NotEqualLayer : public Layer<Dtype> {
  public:
   explicit NotEqualLayer(const LayerParameter &param)
       : Layer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype> *> &bottom,
+                          const vector<Blob<Dtype> *> &top);
   virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
                        const vector<Blob<Dtype>*>& top);
 
@@ -30,6 +32,9 @@ template <typename Dtype> class NotEqualLayer : public Layer<Dtype> {
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
     NOT_IMPLEMENTED;
   };
+
+  float comparand_;
+  int const_flag_;
 };
 
 } // namespace caffe
```
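For context, Caffe's `Layer<Dtype>::SetUp` invokes `LayerSetUp` exactly once while the net is constructed, and `Reshape` again whenever input shapes change, which is why the one-time parsing of the comparand moves into the new `LayerSetUp` override. A rough paraphrase of the upstream call order (from Caffe's `layer.hpp`; the exact checks are elided here):

```cpp
// Paraphrased sketch of Layer<Dtype>::SetUp from upstream Caffe's layer.hpp;
// not part of this commit. LayerSetUp is the one-shot hook, Reshape the
// repeatable one.
template <typename Dtype>
void Layer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
                         const vector<Blob<Dtype>*>& top) {
  CheckBlobCounts(bottom, top);  // validate bottom/top blob counts
  LayerSetUp(bottom, top);       // one-time work (here: comparand_, const_flag_)
  Reshape(bottom, top);          // per-shape work (here: size top[0] like bottom[0])
  SetLossWeights(top);           // wire loss weights, if any
}
```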

src/caffe/layers/not_equal_layer.cpp

Lines changed: 29 additions & 21 deletions

```diff
@@ -6,41 +6,49 @@
 namespace caffe {
 
 template <typename Dtype>
-void NotEqualLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
-                                   const vector<Blob<Dtype>*>& top) {
-
-  Blob<Dtype>* comparand = (bottom.size() > 1) ? bottom[1] : this->blobs_[0].get();
-  vector<int> bottom_shape = bottom[0]->shape();
-
-  // case 1: bottom[1] is a scalar(bottom[0] may be a scalar)
-  if (comparand->num_axes() == 0) {}
-  // case 2: bottom[0] and bottom[1] are tensor and have the same dimension
-  else if (comparand->num_axes() == bottom[0]->num_axes()) {
-    for (int i = 0; i < bottom[0]->num_axes(); ++i) {
-      CHECK_EQ(bottom[0]->shape(i), comparand->shape(i)) << "Broadcasting is not supported now!!! Please confirm that 2 inputs have the same shape!!";
-    }
+void NotEqualLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
+                                      const vector<Blob<Dtype> *> &top) {
+  if ((bottom.size() == 1) && (this->blobs_.size() == 0)) {
+    const NotEqualParameter &not_equal_param = this->layer_param_.not_equal_param();
+    comparand_ = not_equal_param.comparand();
+    const_flag_ = 1;
   }
-  // case 3: bottom[0] and bottom[1] are tensor and have different dimension/shape
   else {
-    CHECK_EQ(bottom[0]->num_axes(), comparand->num_axes()) << "Broadcasting is not supported now!!! Please confirm that 2 inputs have the same shape!!";
+    const_flag_ = 0;
+    Blob<Dtype>* comparand = (bottom.size() > 1) ? bottom[1] : this->blobs_[0].get();
+    // case 1: bottom[0] and bottom[1] are tensor and have the same dimension
+    if (comparand->num_axes() == bottom[0]->num_axes()) {
+      for (int i = 0; i < bottom[0]->num_axes(); ++i) {
+        CHECK_EQ(bottom[0]->shape(i), comparand->shape(i)) << "Broadcasting is not supported now!!! Please confirm that 2 inputs have the same shape!!";
+      }
+    }
+    // case 2: bottom[0] and bottom[1] are tensor and have different dimension/shape
+    else {
+      CHECK_EQ(bottom[0]->num_axes(), comparand->num_axes()) << "Broadcasting is not supported now!!! Please confirm that 2 inputs have the same shape!!";
+    }
   }
-  vector<int> top_shape = bottom_shape;
-  top[0]->Reshape(top_shape);
+}
+
+template <typename Dtype>
+void NotEqualLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+                                   const vector<Blob<Dtype>*>& top) {
+  vector<int> bottom_shape = bottom[0]->shape();
+  top[0]->Reshape(bottom_shape);
 }
 
 template <typename Dtype>
 void NotEqualLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
                                        const vector<Blob<Dtype>*>& top) {
   const Dtype* bottom_data = bottom[0]->cpu_data();
-  Blob<Dtype>* comparand = (bottom.size() > 1) ? bottom[1] : this->blobs_[0].get();
-  const Dtype* comparand_data = comparand->cpu_data();
   Dtype* top_data = top[0]->mutable_cpu_data();
-  if (comparand->num_axes() == 0) {
+  if (const_flag_ == 1) {
     for (int i = 0; i < top[0]->count(); ++i) {
-      top_data[i] = bool(bottom_data[i] != comparand_data[0]);
+      top_data[i] = bool(bottom_data[i] != comparand_);
     }
   }
   else {
+    Blob<Dtype>* comparand = (bottom.size() > 1) ? bottom[1] : this->blobs_[0].get();
+    const Dtype* comparand_data = comparand->cpu_data();
     for (int i = 0; i < top[0]->count(); ++i) {
       top_data[i] = bool(bottom_data[i] != comparand_data[i]);
     }
```
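The net effect of the refactor on `Forward_cpu` is easiest to see in isolation. A minimal standalone sketch (plain C++ with made-up inputs, not Caffe code) of the two paths the flag selects:

```cpp
// Illustrates the layer's two forward paths outside Caffe (hypothetical data).
// Path 1 mirrors const_flag_ == 1: every element is compared against the
// scalar comparand parsed from the proto. Path 2 mirrors const_flag_ == 0:
// elementwise comparison against a second tensor of identical shape.
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> bottom = {0.f, 1.f, 2.f, 1.f};

  // Path 1: scalar comparand (const_flag_ == 1).
  const float comparand = 1.f;
  for (std::size_t i = 0; i < bottom.size(); ++i)
    std::printf("%d ", static_cast<int>(bottom[i] != comparand));
  std::printf("\n");  // prints: 1 0 1 0

  // Path 2: same-shape comparand tensor (const_flag_ == 0).
  const std::vector<float> comparand_data = {0.f, 0.f, 2.f, 3.f};
  for (std::size_t i = 0; i < bottom.size(); ++i)
    std::printf("%d ", static_cast<int>(bottom[i] != comparand_data[i]));
  std::printf("\n");  // prints: 0 1 0 1
  return 0;
}
```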

src/caffe/proto/caffe.proto

Lines changed: 5 additions & 0 deletions

```diff
@@ -673,6 +673,7 @@ message LayerParameter {
   optional AttentionParameter attention_param = 270;
   optional RNNv2Parameter rnn_v2_param = 272;
   optional CountNonzeroParameter count_nonzero_param = 273;
+  optional NotEqualParameter not_equal_param = 274;
 
   //ONNX related
   optional NonMaxSuppressionParameter non_max_suppression_param = 271;
@@ -3377,3 +3378,7 @@ message CountNonzeroParameter {
   repeated int32 axis = 1;
   optional bool keepdims = 2[default = false];
 }
+
+message NotEqualParameter {
+  optional float comparand = 1;
+}
```
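With `NotEqualParameter` registered as field 274 of `LayerParameter`, a net definition can pass the scalar comparand inline rather than through a separate Parameter layer. A hypothetical prototxt snippet (the `"NotEqual"` type string and the blob names are assumed here, not shown in this commit):

```
layer {
  name: "ne"
  type: "NotEqual"    # assumed registered type for NotEqualLayer
  bottom: "data"
  top: "mask"
  not_equal_param {
    comparand: 1.0    # compared against every element of "data"
  }
}
```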
