#include <string>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/layers/simple_rnn_layer.hpp"
#include "caffe/util/format.hpp"  // for format_int
#include "caffe/util/math_functions.hpp"

namespace caffe {

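// The single recurrent input to the unrolled net is the initial hidden
// state, h_0.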
template <typename Dtype>
void SimpleRNNLayer<Dtype>::RecurrentInputBlobNames(
    vector<string>* names) const {
  names->resize(1);
  (*names)[0] = "h_0";
}

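// The single recurrent output is the hidden state after the final
// timestep, h_T.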
template <typename Dtype>
void SimpleRNNLayer<Dtype>::RecurrentOutputBlobNames(
    vector<string>* names) const {
  names->resize(1);
  (*names)[0] = "h_" + format_int(this->T_);
}

template <typename Dtype>
void SimpleRNNLayer<Dtype>::RecurrentInputShapes(
    vector<BlobShape>* shapes) const {
  const int num_output = this->layer_param_.recurrent_param().num_output();
  shapes->resize(1);
  (*shapes)[0].Clear();
  (*shapes)[0].add_dim(1);  // a single timestep
  (*shapes)[0].add_dim(this->N_);
  (*shapes)[0].add_dim(num_output);
}

template <typename Dtype>
void SimpleRNNLayer<Dtype>::OutputBlobNames(vector<string>* names) const {
  names->resize(1);
  (*names)[0] = "h";
}

template <typename Dtype>
void SimpleRNNLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
  const int num_output = this->layer_param_.recurrent_param().num_output();

  CHECK_GT(num_output, 0) << "num_output must be positive";
  const FillerParameter& weight_filler =
      this->layer_param_.recurrent_param().weight_filler();
  const FillerParameter& bias_filler =
      this->layer_param_.recurrent_param().bias_filler();

  // Add generic LayerParameter's (without bottoms/tops) of layer types we'll
  // use to save redundant code.
  LayerParameter hidden_param;
  hidden_param.set_type("InnerProduct");
  hidden_param.mutable_inner_product_param()->set_num_output(num_output);
  hidden_param.mutable_inner_product_param()->set_bias_term(false);
  hidden_param.mutable_inner_product_param()->set_axis(2);
  hidden_param.mutable_inner_product_param()->
      mutable_weight_filler()->CopyFrom(weight_filler);

  LayerParameter biased_hidden_param(hidden_param);
  biased_hidden_param.mutable_inner_product_param()->set_bias_term(true);
  biased_hidden_param.mutable_inner_product_param()->
      mutable_bias_filler()->CopyFrom(bias_filler);

  LayerParameter sum_param;
  sum_param.set_type("Eltwise");
  sum_param.mutable_eltwise_param()->set_operation(
      EltwiseParameter_EltwiseOp_SUM);

  LayerParameter scale_param;
  scale_param.set_type("Scale");
  scale_param.mutable_scale_param()->set_axis(0);

  LayerParameter slice_param;
  slice_param.set_type("Slice");
  slice_param.mutable_slice_param()->set_axis(0);

  // Choose the activation layer type from the ONNX RNN "activations"
  // attribute; the default activation is TanH.
  LayerParameter F_activation_param;
  if (this->activations_.size() == 0 || this->activations_[0] == "Tanh" ||
      this->activations_[0] == "tanh") {
    F_activation_param.set_type("TanH");
  } else if (this->activations_[0] == "Elu" ||
             this->activations_[0] == "elu") {
    // ONNX name differs from the Caffe layer type name.
    F_activation_param.set_type("ELU");
    if (this->activation_alpha_.size() > 0) {
      F_activation_param.mutable_elu_param()->set_alpha(
          this->activation_alpha_[0]);
    }
  } else if (this->activations_[0] == "LeakyRelu") {
    F_activation_param.set_type("ReLU");
    if (this->activation_alpha_.size() > 0) {
      F_activation_param.mutable_relu_param()->set_negative_slope(
          this->activation_alpha_[0]);
    }
  } else if (this->activations_[0] == "Relu" ||
             this->activations_[0] == "relu") {
    F_activation_param.set_type("ReLU");
  } else if (this->activations_[0] == "ScaledTanh") {
    F_activation_param.set_type("ScaledTanH");
    if (this->activation_alpha_.size() > 0) {
      F_activation_param.mutable_scaled_tanh_param()->set_alpha(
          this->activation_alpha_[0]);
    }
    if (this->activation_beta_.size() > 0) {
      F_activation_param.mutable_scaled_tanh_param()->set_beta(
          this->activation_beta_[0]);
    }
  } else if (this->activations_[0] == "ThresholdedRelu") {
    F_activation_param.set_type("ThresholdedReLU");
    if (this->activation_alpha_.size() > 0) {
      F_activation_param.mutable_thresholded_relu_param()->set_alpha(
          this->activation_alpha_[0]);
    }
  } else if (this->activations_[0] == "HardSigmoid") {
    // ONNX name matches the Caffe layer type name.
    F_activation_param.set_type("HardSigmoid");
    if (this->activation_alpha_.size() > 0) {
      F_activation_param.mutable_hard_sigmoid_param()->set_alpha(
          this->activation_alpha_[0]);
    }
    if (this->activation_beta_.size() > 0) {
      F_activation_param.mutable_hard_sigmoid_param()->set_beta(
          this->activation_beta_[0]);
    }
  } else if (this->activations_[0] == "Sigmoid" ||
             this->activations_[0] == "sigmoid") {
    F_activation_param.set_type("Sigmoid");
  } else if (this->activations_[0] == "Softsign" ||
             this->activations_[0] == "softsign") {
    F_activation_param.set_type("Softsign");
  }

  vector<BlobShape> input_shapes;
  RecurrentInputShapes(&input_shapes);
  CHECK_EQ(1, input_shapes.size());

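  // Add an Input layer producing the recurrent input h_0 with shape
  // 1 x N x num_output (a single timestep of hidden state).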
  LayerParameter* input_layer_param = net_param->add_layer();
  input_layer_param->set_type("Input");
  InputParameter* input_param = input_layer_param->mutable_input_param();
  input_layer_param->add_top("h_0");
  input_param->add_shape()->CopyFrom(input_shapes[0]);

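  // Slice the sequence-continuation indicators along the time axis so that
  // each timestep t gets its own cont_t blob.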
  LayerParameter* cont_slice_param = net_param->add_layer();
  cont_slice_param->CopyFrom(slice_param);
  cont_slice_param->set_name("cont_slice");
  cont_slice_param->add_bottom("cont");
  cont_slice_param->mutable_slice_param()->set_axis(0);

  // Add layer to transform all timesteps of x to the hidden state dimension.
  //     W_xh_x = W_xh * x + b_h
  {
    LayerParameter* x_transform_param = net_param->add_layer();
    x_transform_param->CopyFrom(biased_hidden_param);
    x_transform_param->set_name("x_transform");
    x_transform_param->add_param()->set_name("W_xh");
    x_transform_param->add_param()->set_name("b_h");
    x_transform_param->add_bottom("x");
    x_transform_param->add_top("W_xh_x");
    x_transform_param->add_propagate_down(true);
  }

  if (this->static_input_) {
    // Add layer to transform x_static to the hidden state dimension.
    //     W_xh_x_static = W_xh_static * x_static
    LayerParameter* x_static_transform_param = net_param->add_layer();
    x_static_transform_param->CopyFrom(hidden_param);
    x_static_transform_param->mutable_inner_product_param()->set_axis(1);
    x_static_transform_param->set_name("W_xh_x_static");
    x_static_transform_param->add_param()->set_name("W_xh_static");
    x_static_transform_param->add_bottom("x_static");
    x_static_transform_param->add_top("W_xh_x_static_preshape");
    x_static_transform_param->add_propagate_down(true);

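    // Reshape W_xh_x_static_preshape from N x num_output to
    // 1 x N x num_output (a single timestep) so it can be summed with the
    // per-timestep terms below.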
    LayerParameter* reshape_param = net_param->add_layer();
    reshape_param->set_type("Reshape");
    BlobShape* new_shape =
        reshape_param->mutable_reshape_param()->mutable_shape();
    new_shape->add_dim(1);  // One timestep.
    // Should infer this->N as the dimension so we can reshape on batch size.
    new_shape->add_dim(-1);
    new_shape->add_dim(
        x_static_transform_param->inner_product_param().num_output());
    reshape_param->set_name("W_xh_x_static_reshape");
    reshape_param->add_bottom("W_xh_x_static_preshape");
    reshape_param->add_top("W_xh_x_static");
  }

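  // Slice W_xh_x along the time axis into per-timestep blobs W_xh_x_t,
  // which are consumed inside the unrolled loop below.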
  LayerParameter* x_slice_param = net_param->add_layer();
  x_slice_param->CopyFrom(slice_param);
  x_slice_param->set_name("W_xh_x_slice");
  x_slice_param->add_bottom("W_xh_x");

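  // Concatenate the per-timestep hidden states h_1, ..., h_T along the time
  // axis to form the layer output h; bottoms are added in the loop below.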
  LayerParameter output_concat_layer;
  output_concat_layer.set_name("h_concat");
  output_concat_layer.set_type("Concat");
  output_concat_layer.add_top("h");
  output_concat_layer.mutable_concat_param()->set_axis(0);

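  // Unroll the recurrence over the T_ timesteps; each iteration adds the
  // layers that compute h_t from h_{t-1} and W_xh_x_t.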
  for (int t = 1; t <= this->T_; ++t) {
    string tm1s = format_int(t - 1);
    string ts = format_int(t);

    cont_slice_param->add_top("cont_" + ts);
    x_slice_param->add_top("W_xh_x_" + ts);

    // Add layer to flush the hidden state when beginning a new sequence,
    // as indicated by cont_t.
    //     h_conted_{t-1} := cont_t * h_{t-1}
    //
    // Normally, cont_t is binary (i.e., 0 or 1), so:
    //     h_conted_{t-1} := h_{t-1} if cont_t == 1
    //                       0       otherwise
    {
      LayerParameter* cont_h_param = net_param->add_layer();
      cont_h_param->CopyFrom(scale_param);
      cont_h_param->set_name("h_conted_" + tm1s);
      cont_h_param->add_bottom("h_" + tm1s);
      cont_h_param->add_bottom("cont_" + ts);
      cont_h_param->add_top("h_conted_" + tm1s);
    }

    // Add layer to compute
    //     W_hh_h_{t-1} := W_hh * h_conted_{t-1}
    {
      LayerParameter* w_param = net_param->add_layer();
      w_param->CopyFrom(hidden_param);
      w_param->set_name("W_hh_h_" + tm1s);
      w_param->add_param()->set_name("W_hh");
      w_param->add_bottom("h_conted_" + tm1s);
      w_param->add_top("W_hh_h_" + tm1s);
      w_param->mutable_inner_product_param()->set_axis(2);
    }

    // Add layers to compute
    //     h_t := F_activation( W_hh * h_conted_{t-1} + W_xh * x_t + b_h )
    //          = F_activation( W_hh_h_{t-1} + W_xh_x_t )
    {
      LayerParameter* h_input_sum_param = net_param->add_layer();
      h_input_sum_param->CopyFrom(sum_param);
      h_input_sum_param->set_name("h_input_sum_" + ts);
      h_input_sum_param->add_bottom("W_hh_h_" + tm1s);
      h_input_sum_param->add_bottom("W_xh_x_" + ts);
      if (this->static_input_) {
        h_input_sum_param->add_bottom("W_xh_x_static");
      }
      h_input_sum_param->add_top("h_neuron_input_" + ts);
    }
    {
      LayerParameter* h_neuron_param = net_param->add_layer();
      h_neuron_param->CopyFrom(F_activation_param);
      h_neuron_param->set_name("h_neuron_" + ts);
      h_neuron_param->add_bottom("h_neuron_input_" + ts);
      h_neuron_param->add_top("h_" + ts);
    }
    output_concat_layer.add_bottom("h_" + ts);
  }  // for (int t = 1; t <= this->T_; ++t)

  net_param->add_layer()->CopyFrom(output_concat_layer);
}

INSTANTIATE_CLASS(SimpleRNNLayer);
REGISTER_LAYER_CLASS(SimpleRNN);

}  // namespace caffe