
Commit a0e7e86

support NotEqual layer without broadcasting

1 parent 7fd22a5 commit a0e7e86

2 files changed: +92 -0 lines changed
caffe/layers/not_equal_layer.hpp: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
#ifndef CAFFE_NOT_EQUAL_LAYER_HPP_
#define CAFFE_NOT_EQUAL_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

// Implementation of the TensorFlow operator:
// https://www.tensorflow.org/api_docs/python/tf/math/not_equal

template <typename Dtype>
class NotEqualLayer : public Layer<Dtype> {
 public:
  explicit NotEqualLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "NotEqual"; }
  virtual inline int MinBottomBlobs() const { return 1; }
  virtual inline int ExactNumTopBlobs() const { return 1; }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  // Element-wise comparison is not differentiable, so no gradient is defined.
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down,
      const vector<Blob<Dtype>*>& bottom) {
    NOT_IMPLEMENTED;
  }
};

}  // namespace caffe

#endif  // CAFFE_NOT_EQUAL_LAYER_HPP_
not_equal_layer.cpp: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
#include <vector>

#include "caffe/layers/not_equal_layer.hpp"

namespace caffe {

template <typename Dtype>
void NotEqualLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  // The comparand is either the second bottom blob or, when only one bottom
  // is given, the layer's own stored blob (e.g. loaded from a caffemodel).
  Blob<Dtype>* comparand =
      (bottom.size() > 1) ? bottom[1] : this->blobs_[0].get();
  vector<int> bottom_shape = bottom[0]->shape();

  if (comparand->num_axes() == 0) {
    // case 1: the comparand is a scalar (bottom[0] may also be a scalar),
    // so there is nothing to check
  } else if (comparand->num_axes() == bottom[0]->num_axes()) {
    // case 2: both inputs are tensors with the same number of axes;
    // every axis must match exactly
    for (int i = 0; i < bottom[0]->num_axes(); ++i) {
      CHECK_EQ(bottom[0]->shape(i), comparand->shape(i))
          << "Broadcasting is not supported yet; "
          << "please make sure the two inputs have the same shape.";
    }
  } else {
    // case 3: the inputs are tensors with different numbers of axes
    CHECK_EQ(bottom[0]->num_axes(), comparand->num_axes())
        << "Broadcasting is not supported yet; "
        << "please make sure the two inputs have the same shape.";
  }
  top[0]->Reshape(bottom_shape);
}

template <typename Dtype>
void NotEqualLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const Dtype* bottom_data = bottom[0]->cpu_data();
  Blob<Dtype>* comparand =
      (bottom.size() > 1) ? bottom[1] : this->blobs_[0].get();
  const Dtype* comparand_data = comparand->cpu_data();
  Dtype* top_data = top[0]->mutable_cpu_data();
  if (comparand->num_axes() == 0) {
    // scalar comparand: compare every input element against the single value
    for (int i = 0; i < top[0]->count(); ++i) {
      top_data[i] = Dtype(bottom_data[i] != comparand_data[0]);
    }
  } else {
    // same-shape comparand: element-wise comparison
    for (int i = 0; i < top[0]->count(); ++i) {
      top_data[i] = Dtype(bottom_data[i] != comparand_data[i]);
    }
  }
}

INSTANTIATE_CLASS(NotEqualLayer);
REGISTER_LAYER_CLASS(NotEqual);

}  // namespace caffe
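
For reference, a minimal usage sketch (not part of this commit) showing how the registered layer could be exercised through the standard Caffe C++ API. The shapes, values, and the main() harness are illustrative assumptions; since the commit adds no new proto parameters, a bare LayerParameter with type "NotEqual" is enough:

#include <cstdio>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/layer_factory.hpp"
#include "caffe/proto/caffe.pb.h"

int main() {
  using namespace caffe;

  // Two 1x1x2x2 inputs that differ in exactly one element.
  Blob<float> a(1, 1, 2, 2), b(1, 1, 2, 2), out;
  for (int i = 0; i < a.count(); ++i) {
    a.mutable_cpu_data()[i] = 1.0f;
    b.mutable_cpu_data()[i] = 1.0f;
  }
  b.mutable_cpu_data()[3] = 2.0f;

  // Create the layer through the registry, as a Net would.
  LayerParameter param;
  param.set_type("NotEqual");
  shared_ptr<Layer<float> > layer = LayerRegistry<float>::CreateLayer(param);

  vector<Blob<float>*> bottom, top;
  bottom.push_back(&a);
  bottom.push_back(&b);
  top.push_back(&out);

  layer->SetUp(bottom, top);   // runs Reshape, which enforces equal shapes
  layer->Forward(bottom, top);

  // Prints "0 0 0 1": only the last elements differ.
  for (int i = 0; i < out.count(); ++i) {
    printf("%g ", out.cpu_data()[i]);
  }
  printf("\n");
  return 0;
}

With a single bottom, the comparand would instead be read from this->blobs_[0], and a comparand blob with zero axes takes the scalar branch of Forward_cpu.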
