
Commit 3936762

workable revision for RNNv2 (bidirectional)
1 parent dfedd21 commit 3936762

File tree

2 files changed: +88 / -88 lines


include/caffe/layers/rnn_v2_layer.hpp

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@ template <typename Dtype> class RNNv2Layer : public Layer<Dtype> {
                           const vector<Blob<Dtype> *> &top);
   virtual void Reshape(const vector<Blob<Dtype> *> &bottom,
                        const vector<Blob<Dtype> *> &top);
-  virtual void Reset();
 
   virtual inline const char *type() const { return "RNNv2"; }
   virtual inline int MinBottomBlobs() const {

src/caffe/layers/rnn_v2_layer.cpp

Lines changed: 88 additions & 87 deletions
@@ -39,10 +39,12 @@ void RNNv2Layer<Dtype>::OutputBlobNames(vector<string> *names) const {
 }
 
 template <typename Dtype>
-void RNNv2Layer<Dtype>::FillUnrolledNet(
-    NetParameter *net_param, const string x_name, const string cont_name,
-    vector<string> output_names, vector<string> recur_name_prefix,
-    const string &layer_name_prefix) {
+void RNNv2Layer<Dtype>::FillUnrolledNet(NetParameter *net_param,
+                                        const string x_name,
+                                        const string cont_name,
+                                        vector<string> output_names,
+                                        vector<string> recur_name_prefix,
+                                        const string &layer_name_prefix) {
   const int hidden_size = this->layer_param_.rnn_v2_param().hidden_size();
   CHECK_GT(hidden_size, 0) << "hidden_size must be positive";
 
@@ -437,6 +439,7 @@ void RNNv2Layer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
   {
     LayerParameter slice_param;
     slice_param.set_type("Slice");
+    slice_param.mutable_slice_param()->set_axis(0);
     LayerParameter *recur_input_copy_param;
     for (int i = 0; i < num_recur_blobs; ++i) {
       recur_input_copy_param = net_param.add_layer();
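
The explicit set_axis(0) matters because Caffe's SliceParameter defaults to axis 1, so any split actually performed by these recurrent-input copy layers would otherwise happen along the second axis instead of the first. A minimal sketch of what slicing along axis 0 means, assuming an illustrative row-major blob of shape (2, N, H) (the shape is an assumption for the example, not read from the diff):

#include <array>
#include <cstddef>
#include <vector>

// Illustrative only: split a row-major buffer of shape (2, N, H) along
// axis 0 into two (1, N, H) halves, the way a Caffe Slice layer configured
// with slice_param { axis: 0 } and two tops would.
std::array<std::vector<float>, 2> SliceAxis0(const std::vector<float> &blob,
                                             std::size_t N, std::size_t H) {
  const std::size_t half = N * H;  // elements per axis-0 slice
  return {std::vector<float>(blob.begin(), blob.begin() + half),
          std::vector<float>(blob.begin() + half, blob.begin() + 2 * half)};
}
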
@@ -517,13 +520,13 @@ void RNNv2Layer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
     string merge_mode = this->layer_param_.rnn_v2_param().merge_mode();
     // https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/keras/layers/wrappers.py#L506
     if (merge_mode == "concat") {
-      LayerParameter output_concat_layer;
-      output_concat_layer.set_type("Concat");
-      output_concat_layer.add_bottom("fw_" + output_names[0]);
-      output_concat_layer.add_bottom("bw_rev_" + output_names[0]);
-      output_concat_layer.add_top(output_names[0]);
-      output_concat_layer.set_name(output_names[0]);
-      output_concat_layer.mutable_concat_param()->set_axis(-1);
+      LayerParameter *outputs_layer = net_param.add_layer();
+      outputs_layer->set_type("Concat");
+      outputs_layer->add_bottom("fw_" + output_names[0]);
+      outputs_layer->add_bottom("bw_rev_" + output_names[0]);
+      outputs_layer->add_top(output_names[0]);
+      outputs_layer->set_name(output_names[0]);
+      outputs_layer->mutable_concat_param()->set_axis(-1);
     } else {
       LOG(ERROR)
           << "The value of merge_mode of RNNv2 layer is not supported: "
@@ -585,6 +588,7 @@ void RNNv2Layer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
       }
     }
   } else {
+    // bidirectional
     const int hidden_size = this->layer_param_.rnn_v2_param().hidden_size();
     const int input_size = bottom[0]->shape(2);
 
@@ -601,89 +605,86 @@ void RNNv2Layer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
   }
 }
 
-template <typename Dtype>
-void RNNv2Layer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
-                                const vector<Blob<Dtype> *> &top) {
-  CHECK_GE(bottom[0]->num_axes(), 2)
-      << "bottom[0] must have at least 2 axes -- (#timesteps, #streams, ...)";
-  CHECK_EQ(T_, bottom[0]->shape(0)) << "input number of timesteps changed";
-  N_ = bottom[0]->shape(1);
-  CHECK_EQ(bottom[1]->num_axes(), 2)
-      << "bottom[1] must have exactly 2 axes -- (#timesteps, #streams)";
-  CHECK_EQ(T_, bottom[1]->shape(0));
-  CHECK_EQ(N_, bottom[1]->shape(1));
-  x_input_blob_->ReshapeLike(*bottom[0]);
-  vector<int> cont_shape = bottom[1]->shape();
-  cont_input_blob_->Reshape(cont_shape);
-  vector<BlobShape> recur_input_shapes;
-  RecurrentInputShapes(&recur_input_shapes);
-  CHECK_EQ(recur_input_shapes.size(), recur_input_blobs_.size());
-  for (int i = 0; i < recur_input_shapes.size(); ++i) {
-    recur_input_blobs_[i]->Reshape(recur_input_shapes[i]);
-  }
-  unrolled_net_->Reshape();
-  x_input_blob_->ShareData(*bottom[0]);
-  x_input_blob_->ShareDiff(*bottom[0]);
-  cont_input_blob_->ShareData(*bottom[1]);
-  const int bottom_offset = 2;
-  for (int i = bottom_offset, j = 0; i < bottom.size(); ++i, ++j) {
-    CHECK(recur_input_blobs_[j]->shape() == bottom[i]->shape())
-        << "shape mismatch - recur_input_blobs_[" << j
-        << "]: " << recur_input_blobs_[j]->shape_string() << " vs. bottom[" << i
-        << "]: " << bottom[i]->shape_string();
-    recur_input_blobs_[j]->ShareData(*bottom[i]);
-  }
-  for (int i = 0; i < output_blobs_.size(); ++i) {
-    top[i]->ReshapeLike(*output_blobs_[i]);
-    top[i]->ShareData(*output_blobs_[i]);
-    top[i]->ShareDiff(*output_blobs_[i]);
-  }
-  const int top_offset = output_blobs_.size();
-  for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) {
-    top[i]->ReshapeLike(*recur_output_blobs_[j]);
+template <typename Dtype>
+void RNNv2Layer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
+                                const vector<Blob<Dtype> *> &top) {
+  CHECK_GE(bottom[0]->num_axes(), 2)
+      << "bottom[0] must have at least 2 axes -- (#timesteps, #streams, ...)";
+  CHECK_EQ(T_, bottom[0]->shape(0)) << "input number of timesteps changed";
+  N_ = bottom[0]->shape(1);
+  CHECK_EQ(bottom[1]->num_axes(), 2)
+      << "bottom[1] must have exactly 2 axes -- (#timesteps, #streams)";
+  CHECK_EQ(T_, bottom[1]->shape(0));
+  CHECK_EQ(N_, bottom[1]->shape(1));
+  x_input_blob_->ReshapeLike(*bottom[0]);
+  vector<int> cont_shape = bottom[1]->shape();
+  cont_input_blob_->Reshape(cont_shape);
+  vector<BlobShape> recur_input_shapes;
+  RecurrentInputShapes(&recur_input_shapes);
+  CHECK_EQ(recur_input_shapes.size(), recur_input_blobs_.size());
+  for (int i = 0; i < recur_input_shapes.size(); ++i) {
+    recur_input_blobs_[i]->Reshape(recur_input_shapes[i]);
+  }
+  unrolled_net_->Reshape();
+  x_input_blob_->ShareData(*bottom[0]);
+  x_input_blob_->ShareDiff(*bottom[0]);
+  cont_input_blob_->ShareData(*bottom[1]);
+  const int bottom_offset = 2;
+  for (int i = bottom_offset, j = 0; i < bottom.size(); ++i, ++j) {
+    CHECK(recur_input_blobs_[j]->shape() == bottom[i]->shape())
+        << "shape mismatch - recur_input_blobs_[" << j
+        << "]: " << recur_input_blobs_[j]->shape_string() << " vs. bottom["
+        << i << "]: " << bottom[i]->shape_string();
+    recur_input_blobs_[j]->ShareData(*bottom[i]);
+  }
+  for (int i = 0; i < output_blobs_.size(); ++i) {
+    top[i]->ReshapeLike(*output_blobs_[i]);
+    top[i]->ShareData(*output_blobs_[i]);
+    top[i]->ShareDiff(*output_blobs_[i]);
+  }
+  const int top_offset = output_blobs_.size();
+  for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) {
+    top[i]->ReshapeLike(*recur_output_blobs_[j]);
+  }
 }
-}
 
-template <typename Dtype> void RNNv2Layer<Dtype>::Reset() {
-  // "Reset" the hidden state of the net by zeroing out all recurrent outputs.
-  for (int i = 0; i < recur_output_blobs_.size(); ++i) {
-    caffe_set(recur_output_blobs_[i]->count(), Dtype(0),
-              recur_output_blobs_[i]->mutable_cpu_data());
-  }
-}
 
-template <typename Dtype>
-void RNNv2Layer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
-                                    const vector<Blob<Dtype> *> &top) {
-  // for bidirectional rnn, split the weights into two parts: forward
-  // and reverse
-  if (direction_ == "bidirectional") {
-    const int blobs_size = this->blobs_.size();
-    for (int i = 0; i < blobs_size; ++i) {
-      if (this->blobs_[i]->count() % 2 != 0)
-        LOG(FATAL) << "The total number of the weight blobs[" << i
-                   << "] cannot be divided by 2";
-      // weight blob for forward
-      Dtype *f = unrolled_net_->params()[i]->mutable_cpu_data();
-      // weight blob for backward
-      Dtype *b = unrolled_net_->params()[i + blobs_size]->mutable_cpu_data();
-      std::memcpy(f, this->blobs_[i]->cpu_data(),
-                  sizeof(Dtype) * this->blobs_[i]->count() / 2);
-      std::memcpy(b, this->blobs_[i]->cpu_data() + this->blobs_[i]->count() / 2,
-                  sizeof(Dtype) * this->blobs_[i]->count() / 2);
+template <typename Dtype>
+void RNNv2Layer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
+                                    const vector<Blob<Dtype> *> &top) {
+  // for bidirectional rnn, split the weights into two parts: forward
+  // and reverse
+  if (direction_ == "bidirectional") {
+    const int blobs_size = this->blobs_.size();
+    CHECK_EQ(unrolled_net_->params().size() % 2, 0);
+    const int params_size = unrolled_net_->params().size() / 2;
+
+    for (int i = 0; i < blobs_size; ++i) {
+      if (this->blobs_[i]->count() % 2 != 0)
+        LOG(FATAL) << "The total number of the weight blobs[" << i
+                   << "] cannot be divided by 2";
+      // weight blob for forward
+      void *f = unrolled_net_->params()[i]->mutable_cpu_data();
+      // weight blob for backward
+      void *b = unrolled_net_->params()[i + params_size]->mutable_cpu_data();
+      std::memcpy(f, this->blobs_[i]->cpu_data(),
+                  sizeof(Dtype) * this->blobs_[i]->count() / 2);
+      std::memcpy(b,
+                  this->blobs_[i]->cpu_data() + this->blobs_[i]->count() / 2,
+                  sizeof(Dtype) * this->blobs_[i]->count() / 2);
+    }
   }
-}
 
-  DCHECK_EQ(recur_input_blobs_.size(), recur_output_blobs_.size());
+  DCHECK_EQ(recur_input_blobs_.size(), recur_output_blobs_.size());
 
-  unrolled_net_->ForwardTo(unrolled_net_->layers().size() - 1);
+  unrolled_net_->ForwardTo(unrolled_net_->layers().size() - 1);
 
-  const int top_offset = output_blobs_.size();
-  for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) {
-    top[i]->ShareData(*recur_output_blobs_[j]);
+  const int top_offset = output_blobs_.size();
+  for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) {
+    top[i]->ShareData(*recur_output_blobs_[j]);
+  }
 }
-}
 
-INSTANTIATE_CLASS(RNNv2Layer);
-REGISTER_LAYER_CLASS(RNNv2);
+INSTANTIATE_CLASS(RNNv2Layer);
+REGISTER_LAYER_CLASS(RNNv2);
 } // namespace caffe
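
The substantive change inside Forward_cpu is the offset used for the backward-direction parameters: the old code indexed unrolled_net_->params() with i + blobs_size (the number of weight blobs held by the RNNv2 layer itself), while the revision indexes with i + params_size, half of the unrolled net's own parameter count, and guards that assumption with the new CHECK_EQ. A simplified sketch of the copy pattern with plain std::vector buffers in place of Caffe's Blob/Net types (SplitWeights and the buffer layout are illustrative, not Caffe API):

#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Sketch of the bidirectional weight hand-off: each source blob holds the
// forward weights in its first half and the backward weights in its second
// half; the unrolled net lists all forward parameter buffers first, then all
// backward ones. Destination buffers are assumed pre-sized to half a blob.
void SplitWeights(const std::vector<std::vector<float>> &layer_blobs,
                  std::vector<std::vector<float>> *unrolled_params) {
  assert(unrolled_params->size() % 2 == 0);
  const std::size_t params_size = unrolled_params->size() / 2;  // per direction
  for (std::size_t i = 0; i < layer_blobs.size(); ++i) {
    assert(layer_blobs[i].size() % 2 == 0);
    const std::size_t half = layer_blobs[i].size() / 2;
    // forward half -> parameter i of the forward sub-net
    std::memcpy((*unrolled_params)[i].data(), layer_blobs[i].data(),
                sizeof(float) * half);
    // backward half -> parameter i of the backward sub-net, offset by
    // params_size (not by layer_blobs.size(), which is what the old code did)
    std::memcpy((*unrolled_params)[i + params_size].data(),
                layer_blobs[i].data() + half, sizeof(float) * half);
  }
}
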
