
Commit d41823e

add SimpleRNN layer(refer to ONNX RNN)

1 parent 6726fb3 commit d41823e

2 files changed: +353 −0 lines changed
Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
#ifndef CAFFE_SIMPLE_RNN_LAYER_HPP_
#define CAFFE_SIMPLE_RNN_LAYER_HPP_

#include <string>
#include <utility>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/layers/recurrent_layer.hpp"
#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

template <typename Dtype> class RecurrentLayer;

/**
 * ONNX specification
 * Notations:
 * X - input tensor
 *
 * i - input gate
 *
 * t - time step (t-1 means previous time step)
 *
 * W[i] - W parameter weight matrix for the input gate
 *
 * R[i] - R recurrence weight matrix for the input gate
 *
 * Wb[i] - W bias vector for the input gate
 *
 * Rb[i] - R bias vector for the input gate
 *
 * WB[i] - W parameter weight matrix for the backward input gate
 *
 * RB[i] - R recurrence weight matrix for the backward input gate
 *
 * WBb[i] - W bias vector for the backward input gate
 *
 * RBb[i] - R bias vector for the backward input gate
 *
 * H - hidden state
 * num_directions - 2 if direction == bidirectional else 1
 *
 *   Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
 *
 * ONNX specification end
 * Inputs:
 *   1. X, shape (T, N, input_size)
 *      - T is the number of time steps
 *      - N is the number of independent streams
 *   2. continuation flag ("cont"), shape (T, N)
 *   3. X_static (optional), shape (N, input_size)
 *   4. init_hidden_state, shape (1, N, num_output)
 *
 * Outputs:
 *   1. outputs, shape (T, N, num_output)
 *   2. final_hidden_state, shape (1, N, num_output)
 *
 * Shapes of weights and biases:
 *   1. W: (num_output, input_size)
 *   2. B: (num_output,)
 *   3. W_static (optional): (num_output, input_size)
 *   4. R: (num_output, num_output)
 */
template <typename Dtype>
class SimpleRNNLayer : public RecurrentLayer<Dtype> {
 public:
  explicit SimpleRNNLayer(const LayerParameter& param)
      : RecurrentLayer<Dtype>(param) {}

  virtual inline const char* type() const { return "SimpleRNN"; }

 protected:
  virtual void FillUnrolledNet(NetParameter* net_param) const;
  virtual void RecurrentInputBlobNames(vector<string>* names) const;
  virtual void RecurrentOutputBlobNames(vector<string>* names) const;
  virtual void RecurrentInputShapes(vector<BlobShape>* shapes) const;
  virtual void OutputBlobNames(vector<string>* names) const;
};

}  // namespace caffe

#endif  // CAFFE_SIMPLE_RNN_LAYER_HPP_
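As a reading aid (not part of the diff): the unrolled net built in the .cpp below implements the single-timestep update documented in the header comment, Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi). A minimal standalone sketch of that update, assuming the default TanH activation and row-major weight storage; the function and buffer names are illustrative only:

#include <cmath>
#include <vector>

// Illustrative only: one SimpleRNN step, h_t = tanh(W * x_t + R * h_{t-1} + Wb + Rb).
// W is (num_output x input_size), R is (num_output x num_output), both row-major.
std::vector<float> simple_rnn_step(const std::vector<float>& x_t,
                                   const std::vector<float>& h_prev,
                                   const std::vector<float>& W,
                                   const std::vector<float>& R,
                                   const std::vector<float>& Wb,
                                   const std::vector<float>& Rb,
                                   int input_size, int num_output) {
  std::vector<float> h_t(num_output);
  for (int o = 0; o < num_output; ++o) {
    float acc = Wb[o] + Rb[o];
    for (int i = 0; i < input_size; ++i) {
      acc += W[o * input_size + i] * x_t[i];   // input contribution
    }
    for (int j = 0; j < num_output; ++j) {
      acc += R[o * num_output + j] * h_prev[j];  // recurrence contribution
    }
    h_t[o] = std::tanh(acc);  // default activation f = tanh
  }
  return h_t;
}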
Lines changed: 266 additions & 0 deletions
@@ -0,0 +1,266 @@
#include <string>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/layers/simple_rnn_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
void SimpleRNNLayer<Dtype>::RecurrentInputBlobNames(vector<string>* names) const {
  names->resize(1);
  (*names)[0] = "h_0";
}

template <typename Dtype>
void SimpleRNNLayer<Dtype>::RecurrentOutputBlobNames(vector<string>* names) const {
  names->resize(1);
  (*names)[0] = "h_" + format_int(this->T_);
}

template <typename Dtype>
void SimpleRNNLayer<Dtype>::RecurrentInputShapes(vector<BlobShape>* shapes) const {
  const int num_output = this->layer_param_.recurrent_param().num_output();
  shapes->resize(1);
  (*shapes)[0].Clear();
  (*shapes)[0].add_dim(1);  // a single timestep
  (*shapes)[0].add_dim(this->N_);
  (*shapes)[0].add_dim(num_output);
}

template <typename Dtype>
void SimpleRNNLayer<Dtype>::OutputBlobNames(vector<string>* names) const {
  names->resize(1);
  (*names)[0] = "h";
}

template <typename Dtype>
void SimpleRNNLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
  const int num_output = this->layer_param_.recurrent_param().num_output();

  CHECK_GT(num_output, 0) << "num_output must be positive";
  const FillerParameter& weight_filler =
      this->layer_param_.recurrent_param().weight_filler();
  const FillerParameter& bias_filler =
      this->layer_param_.recurrent_param().bias_filler();

  // Add generic LayerParameters (without bottoms/tops) for the layer types
  // used below, to avoid redundant code.
  LayerParameter hidden_param;
  hidden_param.set_type("InnerProduct");
  hidden_param.mutable_inner_product_param()->set_num_output(num_output);
  hidden_param.mutable_inner_product_param()->set_bias_term(false);
  hidden_param.mutable_inner_product_param()->set_axis(2);
  hidden_param.mutable_inner_product_param()->
      mutable_weight_filler()->CopyFrom(weight_filler);

  LayerParameter biased_hidden_param(hidden_param);
  biased_hidden_param.mutable_inner_product_param()->set_bias_term(true);
  biased_hidden_param.mutable_inner_product_param()->
      mutable_bias_filler()->CopyFrom(bias_filler);

  LayerParameter sum_param;
  sum_param.set_type("Eltwise");
  sum_param.mutable_eltwise_param()->set_operation(
      EltwiseParameter_EltwiseOp_SUM);

  LayerParameter scale_param;
  scale_param.set_type("Scale");
  scale_param.mutable_scale_param()->set_axis(0);

  LayerParameter slice_param;
  slice_param.set_type("Slice");
  slice_param.mutable_slice_param()->set_axis(0);

  // Choose the activation layer for the ONNX RNN "activations" attribute
  // (defaults to TanH when no activation is specified).
  LayerParameter F_activation_param;
  if ((this->activations_.size() == 0) || (this->activations_[0] == "Tanh") ||
      (this->activations_[0] == "tanh")) {
    F_activation_param.set_type("TanH");
  } else {
    // ONNX activation names that differ from the Caffe layer type name.
    if ((this->activations_[0] == "Elu") || (this->activations_[0] == "elu")) {
      F_activation_param.set_type("ELU");
      if (this->activation_alpha_.size() > 0) {
        F_activation_param.mutable_elu_param()->set_alpha(
            this->activation_alpha_[0]);
      }
    }
    if (this->activations_[0] == "LeakyRelu") {
      F_activation_param.set_type("ReLU");
      if (this->activation_alpha_.size() > 0) {
        F_activation_param.mutable_relu_param()->set_negative_slope(
            this->activation_alpha_[0]);
      }
    }
    if ((this->activations_[0] == "Relu") || (this->activations_[0] == "relu")) {
      F_activation_param.set_type("ReLU");
    }
    if (this->activations_[0] == "ScaledTanh") {
      F_activation_param.set_type("ScaledTanH");
      if (this->activation_alpha_.size() > 0) {
        F_activation_param.mutable_scaled_tanh_param()->set_alpha(
            this->activation_alpha_[0]);
      }
      if (this->activation_beta_.size() > 0) {
        F_activation_param.mutable_scaled_tanh_param()->set_beta(
            this->activation_beta_[0]);
      }
    }
    if (this->activations_[0] == "ThresholdedRelu") {
      F_activation_param.set_type("ThresholdedReLU");
      if (this->activation_alpha_.size() > 0) {
        F_activation_param.mutable_thresholded_relu_param()->set_alpha(
            this->activation_alpha_[0]);
      }
    }
    // ONNX activation names that match the Caffe layer type name.
    if (this->activations_[0] == "HardSigmoid") {
      F_activation_param.set_type("HardSigmoid");
      if (this->activation_alpha_.size() > 0) {
        F_activation_param.mutable_hard_sigmoid_param()->set_alpha(
            this->activation_alpha_[0]);
      }
      if (this->activation_beta_.size() > 0) {
        F_activation_param.mutable_hard_sigmoid_param()->set_beta(
            this->activation_beta_[0]);
      }
    }
    if ((this->activations_[0] == "Sigmoid") ||
        (this->activations_[0] == "sigmoid")) {
      F_activation_param.set_type("Sigmoid");
    }
    if ((this->activations_[0] == "Softsign") ||
        (this->activations_[0] == "softsign")) {
      F_activation_param.set_type("Softsign");
    }
  }

  vector<BlobShape> input_shapes;
  RecurrentInputShapes(&input_shapes);
  CHECK_EQ(1, input_shapes.size());

  LayerParameter* input_layer_param = net_param->add_layer();
  input_layer_param->set_type("Input");
  InputParameter* input_param = input_layer_param->mutable_input_param();
  input_layer_param->add_top("h_0");
  input_param->add_shape()->CopyFrom(input_shapes[0]);

  LayerParameter* cont_slice_param = net_param->add_layer();
  cont_slice_param->CopyFrom(slice_param);
  cont_slice_param->set_name("cont_slice");
  cont_slice_param->add_bottom("cont");
  cont_slice_param->mutable_slice_param()->set_axis(0);

  // Add layer to transform all timesteps of x to the hidden state dimension.
  //     W_xh_x = W_xh * x + b_h
  {
    LayerParameter* x_transform_param = net_param->add_layer();
    x_transform_param->CopyFrom(biased_hidden_param);
    x_transform_param->set_name("x_transform");
    x_transform_param->add_param()->set_name("W_xh");
    x_transform_param->add_param()->set_name("b_h");
    x_transform_param->add_bottom("x");
    x_transform_param->add_top("W_xh_x");
    x_transform_param->add_propagate_down(true);
  }

  if (this->static_input_) {
    // Add layer to transform x_static to the hidden state dimension.
    //     W_xh_x_static = W_xh_static * x_static
    LayerParameter* x_static_transform_param = net_param->add_layer();
    x_static_transform_param->CopyFrom(hidden_param);
    x_static_transform_param->mutable_inner_product_param()->set_axis(1);
    x_static_transform_param->set_name("W_xh_x_static");
    x_static_transform_param->add_param()->set_name("W_xh_static");
    x_static_transform_param->add_bottom("x_static");
    x_static_transform_param->add_top("W_xh_x_static_preshape");
    x_static_transform_param->add_propagate_down(true);

    LayerParameter* reshape_param = net_param->add_layer();
    reshape_param->set_type("Reshape");
    BlobShape* new_shape =
        reshape_param->mutable_reshape_param()->mutable_shape();
    new_shape->add_dim(1);  // One timestep.
    // Should infer this->N as the dimension so we can reshape on batch size.
    new_shape->add_dim(-1);
    new_shape->add_dim(
        x_static_transform_param->inner_product_param().num_output());
    reshape_param->set_name("W_xh_x_static_reshape");
    reshape_param->add_bottom("W_xh_x_static_preshape");
    reshape_param->add_top("W_xh_x_static");
  }

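  // Shape walk-through (added note, not in the original code): with a static
  // input, the InnerProduct above produces W_xh_x_static_preshape of shape
  // (N, num_output); the Reshape turns it into (1, N, num_output) so it matches
  // the per-timestep (1, N, num_output) blobs it is summed with in the loop below.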
  LayerParameter* x_slice_param = net_param->add_layer();
  x_slice_param->CopyFrom(slice_param);
  x_slice_param->set_name("W_xh_x_slice");
  x_slice_param->add_bottom("W_xh_x");

  LayerParameter output_concat_layer;
  output_concat_layer.set_name("h_concat");
  output_concat_layer.set_type("Concat");
  output_concat_layer.add_top("h");
  output_concat_layer.mutable_concat_param()->set_axis(0);

  for (int t = 1; t <= this->T_; ++t) {
    string tm1s = format_int(t - 1);
    string ts = format_int(t);

    cont_slice_param->add_top("cont_" + ts);
    x_slice_param->add_top("W_xh_x_" + ts);

    // Add layer to flush the hidden state when beginning a new sequence,
    // as indicated by cont_t.
    //     h_conted_{t-1} := cont_t * h_{t-1}
    //
    // Normally, cont_t is binary (i.e., 0 or 1), so:
    //     h_conted_{t-1} := h_{t-1} if cont_t == 1
    //                       0       otherwise
    {
      LayerParameter* cont_h_param = net_param->add_layer();
      cont_h_param->CopyFrom(scale_param);
      cont_h_param->set_name("h_conted_" + tm1s);
      cont_h_param->add_bottom("h_" + tm1s);
      cont_h_param->add_bottom("cont_" + ts);
      cont_h_param->add_top("h_conted_" + tm1s);
    }

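    // Worked example (added note, not in the original code): with N = 2 streams
    // and cont_t = [0, 1], the Scale layer above zeroes stream 0's previous
    // hidden state (its sequence just restarted) and passes stream 1's through
    // unchanged; cont_t broadcasts over the num_output axis.
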
    // Add layer to compute
    //     W_hh_h_{t-1} := W_hh * h_conted_{t-1}
    {
      LayerParameter* w_param = net_param->add_layer();
      w_param->CopyFrom(hidden_param);
      w_param->set_name("W_hh_h_" + tm1s);
      w_param->add_param()->set_name("W_hh");
      w_param->add_bottom("h_conted_" + tm1s);
      w_param->add_top("W_hh_h_" + tm1s);
      w_param->mutable_inner_product_param()->set_axis(2);
    }

    // Add layers to compute
    //     h_t := \F_activation( W_hh * h_conted_{t-1} + W_xh * x_t + b_h )
    //          = \F_activation( W_hh_h_{t-1} + W_xh_x_t )
    {
      LayerParameter* h_input_sum_param = net_param->add_layer();
      h_input_sum_param->CopyFrom(sum_param);
      h_input_sum_param->set_name("h_input_sum_" + ts);
      h_input_sum_param->add_bottom("W_hh_h_" + tm1s);
      h_input_sum_param->add_bottom("W_xh_x_" + ts);
      if (this->static_input_) {
        h_input_sum_param->add_bottom("W_xh_x_static");
      }
      h_input_sum_param->add_top("h_neuron_input_" + ts);
    }
    {
      LayerParameter* h_neuron_param = net_param->add_layer();
      h_neuron_param->CopyFrom(F_activation_param);
      h_neuron_param->set_name("h_neuron_" + ts);
      h_neuron_param->add_bottom("h_neuron_input_" + ts);
      h_neuron_param->add_top("h_" + ts);
    }
    output_concat_layer.add_bottom("h_" + ts);
  }  // for (int t = 1; t <= this->T_; ++t)

  net_param->add_layer()->CopyFrom(output_concat_layer);
}

INSTANTIATE_CLASS(SimpleRNNLayer);
REGISTER_LAYER_CLASS(SimpleRNN);

}  // namespace caffe
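For context, a minimal usage sketch (illustrative only, not part of this commit): it assumes a full Caffe build with this layer registered, and follows the blob shapes documented in the header (x: (T, N, input_size), cont: (T, N), h: (T, N, num_output)); the sizes and filler settings below are arbitrary.

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layers/simple_rnn_layer.hpp"
#include "caffe/util/math_functions.hpp"

int main() {
  using namespace caffe;
  const int T = 4, N = 2, input_size = 8, num_output = 16;  // illustrative sizes

  // Configure the layer: type and recurrent_param, as read by FillUnrolledNet().
  LayerParameter param;
  param.set_type("SimpleRNN");
  param.mutable_recurrent_param()->set_num_output(num_output);
  param.mutable_recurrent_param()->mutable_weight_filler()->set_type("uniform");
  param.mutable_recurrent_param()->mutable_weight_filler()->set_min(-0.1f);
  param.mutable_recurrent_param()->mutable_weight_filler()->set_max(0.1f);
  param.mutable_recurrent_param()->mutable_bias_filler()->set_type("constant");

  // Bottom blobs: x is (T, N, input_size), cont is (T, N).
  std::vector<int> x_shape;
  x_shape.push_back(T);
  x_shape.push_back(N);
  x_shape.push_back(input_size);
  std::vector<int> cont_shape;
  cont_shape.push_back(T);
  cont_shape.push_back(N);
  Blob<float> x(x_shape);
  Blob<float> cont(cont_shape);
  Blob<float> h;  // top blob; reshaped by the layer to (T, N, num_output)
  caffe_set(cont.count(), 1.0f, cont.mutable_cpu_data());  // every step continues its stream

  std::vector<Blob<float>*> bottom, top;
  bottom.push_back(&x);
  bottom.push_back(&cont);
  top.push_back(&h);

  SimpleRNNLayer<float> layer(param);
  layer.SetUp(bottom, top);    // builds the unrolled net described above
  layer.Forward(bottom, top);  // h now holds the hidden states for all T steps
  return 0;
}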

0 commit comments
