|
| 1 | +#include <math.h> |
| 2 | +#include <vector> |
| 3 | + |
| 4 | +#include "caffe/layers/not_equal_layer.hpp" |
| 5 | + |
| 6 | +namespace caffe { |
| 7 | + |
| 8 | + template <typename Dtype> |
| 9 | + void NotEqualLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom, |
| 10 | + const vector<Blob<Dtype>*>& top) { |
| 11 | + |
| 12 | + Blob<Dtype>* comparand = (bottom.size() > 1) ? bottom[1] : this->blobs_[0].get(); |
| 13 | + vector<int> bottom_shape = bottom[0]->shape(); |
| 14 | + |
| 15 | + // case 1: bottom[1] is a scalar(bottom[0] may be a scalar) |
| 16 | + if (comparand->num_axes() == 0) {} |
| 17 | + // case 2: bottom[0] and bottom[1] are tensor and have the same dimension |
| 18 | + else if (comparand->num_axes() == bottom[0]->num_axes()) { |
| 19 | + for (int i = 0; i < bottom[0]->num_axes(); ++i) { |
| 20 | + CHECK_EQ(bottom[0]->shape(i), comparand->shape(i)) << "Broadcasting is not supported now!!! Please confirm that 2 inputs have the same shape!!"; |
| 21 | + } |
| 22 | + } |
| 23 | + // case 3: bottom[0] and bottom[1] are tensor and have different dimension/shape |
| 24 | + else { |
| 25 | + CHECK_EQ(bottom[0]->num_axes(), comparand->num_axes()) << "Broadcasting is not supported now!!! Please confirm that 2 inputs have the same shape!!"; |
| 26 | + } |
| 27 | + vector<int> top_shape = bottom_shape; |
| 28 | + top[0]->Reshape(top_shape); |
| 29 | + } |
| 30 | + |
| 31 | + template <typename Dtype> |
| 32 | + void NotEqualLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, |
| 33 | + const vector<Blob<Dtype>*>& top) { |
| 34 | + const Dtype* bottom_data = bottom[0]->cpu_data(); |
| 35 | + Blob<Dtype>* comparand = (bottom.size() > 1) ? bottom[1] : this->blobs_[0].get(); |
| 36 | + const Dtype* comparand_data = comparand->cpu_data(); |
| 37 | + Dtype* top_data = top[0]->mutable_cpu_data(); |
| 38 | + if (comparand->num_axes() == 0) { |
| 39 | + for (int i = 0; i < top[0]->count(); ++i) { |
| 40 | + top_data[i] = bool(bottom_data[i] != comparand_data[0]); |
| 41 | + } |
| 42 | + } |
| 43 | + else { |
| 44 | + for (int i = 0; i < top[0]->count(); ++i) { |
| 45 | + top_data[i] = bool(bottom_data[i] != comparand_data[i]); |
| 46 | + } |
| 47 | + } |
| 48 | + } |
| 49 | + |
| 50 | +INSTANTIATE_CLASS(NotEqualLayer); |
| 51 | +REGISTER_LAYER_CLASS(NotEqual); |
| 52 | + |
| 53 | +} // namespace caffe |
| 54 | + |
0 commit comments