@@ -39,10 +39,12 @@ void RNNv2Layer<Dtype>::OutputBlobNames(vector<string> *names) const {
 }
 
 template <typename Dtype>
-void RNNv2Layer<Dtype>::FillUnrolledNet(
-    NetParameter *net_param, const string x_name, const string cont_name,
-    vector<string> output_names, vector<string> recur_name_prefix,
-    const string &layer_name_prefix) {
+void RNNv2Layer<Dtype>::FillUnrolledNet(NetParameter *net_param,
+                                        const string x_name,
+                                        const string cont_name,
+                                        vector<string> output_names,
+                                        vector<string> recur_name_prefix,
+                                        const string &layer_name_prefix) {
   const int hidden_size = this->layer_param_.rnn_v2_param().hidden_size();
   CHECK_GT(hidden_size, 0) << "hidden_size must be positive";
 
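The reflowed FillUnrolledNet signature above is a formatting-only change; the parameter list itself is untouched. For the bidirectional case this helper is presumably invoked once per direction with distinct name prefixes; the call sites below are invented purely for illustration (only the "fw_"/"bw_rev_" blob prefixes appear later in this diff):

// Hypothetical call sites, not part of this patch; argument names are assumptions.
FillUnrolledNet(&net_param, "x", "cont", fw_output_names, fw_recur_prefixes, "fw_");
FillUnrolledNet(&net_param, "x_rev", "cont_rev", bw_output_names, bw_recur_prefixes, "bw_");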
@@ -437,6 +439,7 @@ void RNNv2Layer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
   {
     LayerParameter slice_param;
     slice_param.set_type("Slice");
+    slice_param.mutable_slice_param()->set_axis(0);
     LayerParameter *recur_input_copy_param;
     for (int i = 0; i < num_recur_blobs; ++i) {
       recur_input_copy_param = net_param.add_layer();
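The one functional change in this hunk is the explicit slice axis: Caffe's SliceParameter defaults to axis 1, so without set_axis(0) these recurrent blobs would be split along the wrong dimension. A minimal sketch of the same pattern; the blob names (h_0, h_0_fw, h_0_bw) and the stacked (2, N, hidden_size) layout are invented here for illustration only:

// Hedged illustration, not taken from this patch: slice a stacked recurrent
// blob of shape (2, N, hidden_size) into one sub-blob per direction.
LayerParameter *slice = net_param.add_layer();
slice->set_type("Slice");
slice->set_name("h_0_slice");
slice->add_bottom("h_0");
slice->add_top("h_0_fw");
slice->add_top("h_0_bw");
slice->mutable_slice_param()->set_axis(0);  // axis 0, not the default axis 1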
@@ -517,13 +520,13 @@ void RNNv2Layer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
     string merge_mode = this->layer_param_.rnn_v2_param().merge_mode();
     // https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/keras/layers/wrappers.py#L506
     if (merge_mode == "concat") {
-      LayerParameter output_concat_layer;
-      output_concat_layer.set_type("Concat");
-      output_concat_layer.add_bottom("fw_" + output_names[0]);
-      output_concat_layer.add_bottom("bw_rev_" + output_names[0]);
-      output_concat_layer.add_top(output_names[0]);
-      output_concat_layer.set_name(output_names[0]);
-      output_concat_layer.mutable_concat_param()->set_axis(-1);
+      LayerParameter *outputs_layer = net_param.add_layer();
+      outputs_layer->set_type("Concat");
+      outputs_layer->add_bottom("fw_" + output_names[0]);
+      outputs_layer->add_bottom("bw_rev_" + output_names[0]);
+      outputs_layer->add_top(output_names[0]);
+      outputs_layer->set_name(output_names[0]);
+      outputs_layer->mutable_concat_param()->set_axis(-1);
     } else {
       LOG(ERROR)
           << "The value of merge_mode of RNNv2 layer is not supported: "
@@ -585,6 +588,7 @@ void RNNv2Layer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
       }
     }
   } else {
+    // bidirectional
     const int hidden_size = this->layer_param_.rnn_v2_param().hidden_size();
     const int input_size = bottom[0]->shape(2);
 
@@ -601,89 +605,86 @@ void RNNv2Layer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
   }
 }
 
-template <typename Dtype>
-void RNNv2Layer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
-                                const vector<Blob<Dtype> *> &top) {
-  CHECK_GE(bottom[0]->num_axes(), 2)
-      << "bottom[0] must have at least 2 axes -- (#timesteps, #streams, ...)";
-  CHECK_EQ(T_, bottom[0]->shape(0)) << "input number of timesteps changed";
-  N_ = bottom[0]->shape(1);
-  CHECK_EQ(bottom[1]->num_axes(), 2)
-      << "bottom[1] must have exactly 2 axes -- (#timesteps, #streams)";
-  CHECK_EQ(T_, bottom[1]->shape(0));
-  CHECK_EQ(N_, bottom[1]->shape(1));
-  x_input_blob_->ReshapeLike(*bottom[0]);
-  vector<int> cont_shape = bottom[1]->shape();
-  cont_input_blob_->Reshape(cont_shape);
-  vector<BlobShape> recur_input_shapes;
-  RecurrentInputShapes(&recur_input_shapes);
-  CHECK_EQ(recur_input_shapes.size(), recur_input_blobs_.size());
-  for (int i = 0; i < recur_input_shapes.size(); ++i) {
-    recur_input_blobs_[i]->Reshape(recur_input_shapes[i]);
-  }
-  unrolled_net_->Reshape();
-  x_input_blob_->ShareData(*bottom[0]);
-  x_input_blob_->ShareDiff(*bottom[0]);
-  cont_input_blob_->ShareData(*bottom[1]);
-  const int bottom_offset = 2;
-  for (int i = bottom_offset, j = 0; i < bottom.size(); ++i, ++j) {
-    CHECK(recur_input_blobs_[j]->shape() == bottom[i]->shape())
-        << "shape mismatch - recur_input_blobs_[" << j
-        << "]: " << recur_input_blobs_[j]->shape_string() << " vs. bottom[" << i
-        << "]: " << bottom[i]->shape_string();
-    recur_input_blobs_[j]->ShareData(*bottom[i]);
-  }
-  for (int i = 0; i < output_blobs_.size(); ++i) {
-    top[i]->ReshapeLike(*output_blobs_[i]);
-    top[i]->ShareData(*output_blobs_[i]);
-    top[i]->ShareDiff(*output_blobs_[i]);
-  }
-  const int top_offset = output_blobs_.size();
-  for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) {
-    top[i]->ReshapeLike(*recur_output_blobs_[j]);
+template <typename Dtype>
+void RNNv2Layer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
+                                const vector<Blob<Dtype> *> &top) {
+  CHECK_GE(bottom[0]->num_axes(), 2)
+      << "bottom[0] must have at least 2 axes -- (#timesteps, #streams, ...)";
+  CHECK_EQ(T_, bottom[0]->shape(0)) << "input number of timesteps changed";
+  N_ = bottom[0]->shape(1);
+  CHECK_EQ(bottom[1]->num_axes(), 2)
+      << "bottom[1] must have exactly 2 axes -- (#timesteps, #streams)";
+  CHECK_EQ(T_, bottom[1]->shape(0));
+  CHECK_EQ(N_, bottom[1]->shape(1));
+  x_input_blob_->ReshapeLike(*bottom[0]);
+  vector<int> cont_shape = bottom[1]->shape();
+  cont_input_blob_->Reshape(cont_shape);
+  vector<BlobShape> recur_input_shapes;
+  RecurrentInputShapes(&recur_input_shapes);
+  CHECK_EQ(recur_input_shapes.size(), recur_input_blobs_.size());
+  for (int i = 0; i < recur_input_shapes.size(); ++i) {
+    recur_input_blobs_[i]->Reshape(recur_input_shapes[i]);
+  }
+  unrolled_net_->Reshape();
+  x_input_blob_->ShareData(*bottom[0]);
+  x_input_blob_->ShareDiff(*bottom[0]);
+  cont_input_blob_->ShareData(*bottom[1]);
+  const int bottom_offset = 2;
+  for (int i = bottom_offset, j = 0; i < bottom.size(); ++i, ++j) {
+    CHECK(recur_input_blobs_[j]->shape() == bottom[i]->shape())
+        << "shape mismatch - recur_input_blobs_[" << j
+        << "]: " << recur_input_blobs_[j]->shape_string() << " vs. bottom["
+        << i << "]: " << bottom[i]->shape_string();
+    recur_input_blobs_[j]->ShareData(*bottom[i]);
+  }
+  for (int i = 0; i < output_blobs_.size(); ++i) {
+    top[i]->ReshapeLike(*output_blobs_[i]);
+    top[i]->ShareData(*output_blobs_[i]);
+    top[i]->ShareDiff(*output_blobs_[i]);
+  }
+  const int top_offset = output_blobs_.size();
+  for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) {
+    top[i]->ReshapeLike(*recur_output_blobs_[j]);
+  }
 }
-}
 
-template <typename Dtype> void RNNv2Layer<Dtype>::Reset() {
-  // "Reset" the hidden state of the net by zeroing out all recurrent outputs.
-  for (int i = 0; i < recur_output_blobs_.size(); ++i) {
-    caffe_set(recur_output_blobs_[i]->count(), Dtype(0),
-              recur_output_blobs_[i]->mutable_cpu_data());
-  }
-}
 
-template <typename Dtype>
-void RNNv2Layer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
-                                    const vector<Blob<Dtype> *> &top) {
-  // for bidirectional rnn, split the weights into two parts: forward
-  // and reverse
-  if (direction_ == "bidirectional") {
-    const int blobs_size = this->blobs_.size();
-    for (int i = 0; i < blobs_size; ++i) {
-      if (this->blobs_[i]->count() % 2 != 0)
-        LOG(FATAL) << "The total number of the weight blobs[" << i
-                   << "] cannot be divided by 2";
-      // weight blob for forward
-      Dtype *f = unrolled_net_->params()[i]->mutable_cpu_data();
-      // weight blob for backward
-      Dtype *b = unrolled_net_->params()[i + blobs_size]->mutable_cpu_data();
-      std::memcpy(f, this->blobs_[i]->cpu_data(),
-                  sizeof(Dtype) * this->blobs_[i]->count() / 2);
-      std::memcpy(b, this->blobs_[i]->cpu_data() + this->blobs_[i]->count() / 2,
-                  sizeof(Dtype) * this->blobs_[i]->count() / 2);
+template <typename Dtype>
+void RNNv2Layer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
+                                    const vector<Blob<Dtype> *> &top) {
+  // for bidirectional rnn, split the weights into two parts: forward
+  // and reverse
+  if (direction_ == "bidirectional") {
+    const int blobs_size = this->blobs_.size();
+    CHECK_EQ(unrolled_net_->params().size() % 2, 0);
+    const int params_size = unrolled_net_->params().size() / 2;
+
+    for (int i = 0; i < blobs_size; ++i) {
+      if (this->blobs_[i]->count() % 2 != 0)
+        LOG(FATAL) << "The total number of the weight blobs[" << i
+                   << "] cannot be divided by 2";
+      // weight blob for forward
+      void *f = unrolled_net_->params()[i]->mutable_cpu_data();
+      // weight blob for backward
+      void *b = unrolled_net_->params()[i + params_size]->mutable_cpu_data();
+      std::memcpy(f, this->blobs_[i]->cpu_data(),
+                  sizeof(Dtype) * this->blobs_[i]->count() / 2);
+      std::memcpy(b,
+                  this->blobs_[i]->cpu_data() + this->blobs_[i]->count() / 2,
+                  sizeof(Dtype) * this->blobs_[i]->count() / 2);
+    }
   }
-  }
 
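The functional fix in Forward_cpu is the offset of the backward-direction parameters: the unrolled net's parameter vector is assumed to list every forward-direction parameter first, then the backward-direction copies, so the backward half starts at params_size = unrolled_net_->params().size() / 2, which need not equal blobs_size. Each of this layer's blobs is treated as the two per-direction weights packed back to back; a hedged sketch of that copy for one blob, written with caffe_copy instead of raw memcpy (the [forward | backward] packing is the assumption here):

// Illustration only, not taken verbatim from the patch.
// Assumed layout: this->blobs_[i]         = [ forward weights | backward weights ]
//                 unrolled_net_->params() = [ forward params ... | backward params ... ]
const int half = this->blobs_[i]->count() / 2;
const Dtype *packed = this->blobs_[i]->cpu_data();
caffe_copy(half, packed, unrolled_net_->params()[i]->mutable_cpu_data());
caffe_copy(half, packed + half,
           unrolled_net_->params()[i + params_size]->mutable_cpu_data());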
-  DCHECK_EQ(recur_input_blobs_.size(), recur_output_blobs_.size());
+  DCHECK_EQ(recur_input_blobs_.size(), recur_output_blobs_.size());
 
-  unrolled_net_->ForwardTo(unrolled_net_->layers().size() - 1);
+  unrolled_net_->ForwardTo(unrolled_net_->layers().size() - 1);
 
-  const int top_offset = output_blobs_.size();
-  for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) {
-    top[i]->ShareData(*recur_output_blobs_[j]);
+  const int top_offset = output_blobs_.size();
+  for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) {
+    top[i]->ShareData(*recur_output_blobs_[j]);
+  }
 }
-}
 
-INSTANTIATE_CLASS(RNNv2Layer);
-REGISTER_LAYER_CLASS(RNNv2);
+INSTANTIATE_CLASS(RNNv2Layer);
+REGISTER_LAYER_CLASS(RNNv2);
 } // namespace caffe