Skip to content

Commit 420d385

Browse files
committed
add hard_tanh_layer for keras custom activation
1 parent a900cd9 commit 420d385

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#ifndef CAFFE_HARD_TANH_LAYER_HPP_
2+
#define CAFFE_HARD_TANH_LAYER_HPP_
3+
4+
#include <vector>
5+
6+
#include "caffe/blob.hpp"
7+
#include "caffe/layer.hpp"
8+
#include "caffe/proto/caffe.pb.h"
9+
10+
#include "caffe/layers/neuron_layer.hpp"
11+
12+
namespace caffe {
13+
14+
// implement of hard_tanh activation
15+
16+
template <typename Dtype> class HardTanhLayer : public NeuronLayer<Dtype> {
17+
public:
18+
explicit HardTanhLayer(const LayerParameter &param)
19+
: NeuronLayer<Dtype>(param) {}
20+
21+
virtual inline const char *type() const { return "HardTanh"; }
22+
23+
protected:
24+
virtual void Forward_cpu(const vector<Blob<Dtype> *> &bottom,
25+
const vector<Blob<Dtype> *> &top);
26+
virtual void Backward_cpu(const vector<Blob<Dtype> *> &top,
27+
const vector<bool> &propagate_down,
28+
const vector<Blob<Dtype> *> &bottom) {
29+
for (int i = 0; i < propagate_down.size(); ++i) {
30+
if (propagate_down[i]) {
31+
NOT_IMPLEMENTED;
32+
}
33+
}
34+
}
35+
};
36+
37+
} // namespace caffe
38+
39+
#endif // CAFFE_HARD_TANH_LAYER_HPP_
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include <math.h>
2+
#include <vector>
3+
4+
#include "caffe/layers/hard_tanh_layer.hpp"
5+
6+
namespace caffe {
7+
8+
template <typename Dtype>
9+
void HardTanhLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
10+
const vector<Blob<Dtype> *> &top) {
11+
const Dtype *bottom_data = bottom[0]->cpu_data();
12+
Dtype *top_data = top[0]->mutable_cpu_data();
13+
for (int i = 0; i < bottom[0]->count(); ++i) {
14+
top_data[i] = (bottom_data[i] > -1) ? bottom_data[i] : Dtype(-1);
15+
top_data[i] = (bottom_data[i] > 1) ? Dtype(1) : top_data[i];
16+
}
17+
}
18+
19+
INSTANTIATE_CLASS(HardTanhLayer);
20+
REGISTER_LAYER_CLASS(HardTanh);
21+
22+
} // namespace caffe

0 commit comments

Comments
 (0)