  }
}  // RNNv2Layer<Dtype>::LayerSetUp

template <typename Dtype>
void RNNv2Layer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
                                const vector<Blob<Dtype> *> &top) {
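  // Illustrative shapes (example values, not from the source): with T_ = 3
  // timesteps, N_ = 2 independent streams, and feature size D, bottom[0] is
  // a (3, 2, D) data blob and bottom[1] is a (3, 2) sequence-continuation
  // indicator.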
  CHECK_GE(bottom[0]->num_axes(), 2)
      << "bottom[0] must have at least 2 axes -- (#timesteps, #streams, ...)";
  CHECK_EQ(T_, bottom[0]->shape(0)) << "input number of timesteps changed";
  N_ = bottom[0]->shape(1);
  CHECK_EQ(bottom[1]->num_axes(), 2)
      << "bottom[1] must have exactly 2 axes -- (#timesteps, #streams)";
  CHECK_EQ(T_, bottom[1]->shape(0));
  CHECK_EQ(N_, bottom[1]->shape(1));
  x_input_blob_->ReshapeLike(*bottom[0]);
  vector<int> cont_shape = bottom[1]->shape();
  cont_input_blob_->Reshape(cont_shape);
  vector<BlobShape> recur_input_shapes;
  RecurrentInputShapes(&recur_input_shapes);
  CHECK_EQ(recur_input_shapes.size(), recur_input_blobs_.size());
  for (int i = 0; i < recur_input_shapes.size(); ++i) {
    recur_input_blobs_[i]->Reshape(recur_input_shapes[i]);
  }
  unrolled_net_->Reshape();
  x_input_blob_->ShareData(*bottom[0]);
  x_input_blob_->ShareDiff(*bottom[0]);
  cont_input_blob_->ShareData(*bottom[1]);
  // Bottoms beyond x and cont carry the initial recurrent state; share
  // their memory with the unrolled net's recurrent input blobs.
  const int bottom_offset = 2;
  for (int i = bottom_offset, j = 0; i < bottom.size(); ++i, ++j) {
    CHECK(recur_input_blobs_[j]->shape() == bottom[i]->shape())
        << "shape mismatch - recur_input_blobs_[" << j
        << "]: " << recur_input_blobs_[j]->shape_string() << " vs. bottom["
        << i << "]: " << bottom[i]->shape_string();
    recur_input_blobs_[j]->ShareData(*bottom[i]);
  }
  // Tops share data and diff with the unrolled net's outputs, so forward
  // and backward need no extra copies for them.
  for (int i = 0; i < output_blobs_.size(); ++i) {
    top[i]->ReshapeLike(*output_blobs_[i]);
    top[i]->ShareData(*output_blobs_[i]);
    top[i]->ShareDiff(*output_blobs_[i]);
  }
  const int top_offset = output_blobs_.size();
  for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) {
    top[i]->ReshapeLike(*recur_output_blobs_[j]);
  }
}
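
// Note: tops [0, output_blobs_.size()) are the per-timestep outputs; any
// remaining tops expose the final recurrent state (reshaped above, data
// shared in Forward_cpu below).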
template <typename Dtype>
void RNNv2Layer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
                                    const vector<Blob<Dtype> *> &top) {
  // For a bidirectional RNN, each weight blob stores the forward-direction
  // parameters followed by the reverse-direction parameters; split each one
  // into the two corresponding parameter blobs of the unrolled net.
  if (direction_ == "bidirectional") {
    const int blobs_size = this->blobs_.size();
    CHECK_EQ(unrolled_net_->params().size() % 2, 0);
    const int params_size = unrolled_net_->params().size() / 2;

    for (int i = 0; i < blobs_size; ++i) {
      if (this->blobs_[i]->count() % 2 != 0)
        LOG(FATAL) << "The element count of weight blob[" << i
                   << "] is not divisible by 2";
      // weight blob for the forward direction
      void *f = unrolled_net_->params()[i]->mutable_cpu_data();
      // weight blob for the reverse direction
      void *b = unrolled_net_->params()[i + params_size]->mutable_cpu_data();
      std::memcpy(f, this->blobs_[i]->cpu_data(),
                  sizeof(Dtype) * this->blobs_[i]->count() / 2);
      std::memcpy(b,
                  this->blobs_[i]->cpu_data() + this->blobs_[i]->count() / 2,
                  sizeof(Dtype) * this->blobs_[i]->count() / 2);
    }
  }
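  // Packed layout assumed by the split above: for a blob with
  // count() == 2 * C,
  //   cpu_data()[0 .. C-1]  -> forward-direction parameters,
  //   cpu_data()[C .. 2C-1] -> reverse-direction parameters,
  // so each memcpy moves C = count() / 2 elements of type Dtype.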

  DCHECK_EQ(recur_input_blobs_.size(), recur_output_blobs_.size());

  // ForwardTo is inclusive of its end index, so this runs every layer of
  // the unrolled net.
  unrolled_net_->ForwardTo(unrolled_net_->layers().size() - 1);

  // Expose the final recurrent states through the trailing tops.
  const int top_offset = output_blobs_.size();
  for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) {
    top[i]->ShareData(*recur_output_blobs_[j]);
  }
}

INSTANTIATE_CLASS(RNNv2Layer);
REGISTER_LAYER_CLASS(RNNv2);
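
// Usage sketch (hypothetical prototxt; the parameter block is an assumption
// mirroring Caffe's built-in recurrent layers -- check caffe.proto for the
// actual RNNv2 parameters):
//
//   layer {
//     name: "rnn1"
//     type: "RNNv2"
//     bottom: "x"     # (T, N, ...) input sequence
//     bottom: "cont"  # (T, N) continuation indicator
//     top: "h"
//     recurrent_param { num_output: 128 }
//   }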
}  // namespace caffe