Skip to content

Commit 7fd22a5

Browse files
committed
reformat RNNv2
1 parent 3adafa7 commit 7fd22a5

File tree

1 file changed

+72
-72
lines changed

1 file changed

+72
-72
lines changed

src/caffe/layers/rnn_v2_layer.cpp

Lines changed: 72 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -605,86 +605,86 @@ void RNNv2Layer<Dtype>::LayerSetUp(const vector<Blob<Dtype> *> &bottom,
605605
}
606606
}
607607

608-
template <typename Dtype>
609-
void RNNv2Layer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
608+
template <typename Dtype>
609+
void RNNv2Layer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
610610
const vector<Blob<Dtype> *> &top) {
611-
CHECK_GE(bottom[0]->num_axes(), 2)
612-
<< "bottom[0] must have at least 2 axes -- (#timesteps, #streams, ...)";
613-
CHECK_EQ(T_, bottom[0]->shape(0)) << "input number of timesteps changed";
614-
N_ = bottom[0]->shape(1);
615-
CHECK_EQ(bottom[1]->num_axes(), 2)
616-
<< "bottom[1] must have exactly 2 axes -- (#timesteps, #streams)";
617-
CHECK_EQ(T_, bottom[1]->shape(0));
618-
CHECK_EQ(N_, bottom[1]->shape(1));
619-
x_input_blob_->ReshapeLike(*bottom[0]);
620-
vector<int> cont_shape = bottom[1]->shape();
621-
cont_input_blob_->Reshape(cont_shape);
622-
vector<BlobShape> recur_input_shapes;
623-
RecurrentInputShapes(&recur_input_shapes);
624-
CHECK_EQ(recur_input_shapes.size(), recur_input_blobs_.size());
625-
for (int i = 0; i < recur_input_shapes.size(); ++i) {
626-
recur_input_blobs_[i]->Reshape(recur_input_shapes[i]);
627-
}
628-
unrolled_net_->Reshape();
629-
x_input_blob_->ShareData(*bottom[0]);
630-
x_input_blob_->ShareDiff(*bottom[0]);
631-
cont_input_blob_->ShareData(*bottom[1]);
632-
const int bottom_offset = 2;
633-
for (int i = bottom_offset, j = 0; i < bottom.size(); ++i, ++j) {
634-
CHECK(recur_input_blobs_[j]->shape() == bottom[i]->shape())
635-
<< "shape mismatch - recur_input_blobs_[" << j
636-
<< "]: " << recur_input_blobs_[j]->shape_string() << " vs. bottom["
637-
<< i << "]: " << bottom[i]->shape_string();
638-
recur_input_blobs_[j]->ShareData(*bottom[i]);
639-
}
640-
for (int i = 0; i < output_blobs_.size(); ++i) {
641-
top[i]->ReshapeLike(*output_blobs_[i]);
642-
top[i]->ShareData(*output_blobs_[i]);
643-
top[i]->ShareDiff(*output_blobs_[i]);
644-
}
645-
const int top_offset = output_blobs_.size();
646-
for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) {
647-
top[i]->ReshapeLike(*recur_output_blobs_[j]);
648-
}
611+
CHECK_GE(bottom[0]->num_axes(), 2)
612+
<< "bottom[0] must have at least 2 axes -- (#timesteps, #streams, ...)";
613+
CHECK_EQ(T_, bottom[0]->shape(0)) << "input number of timesteps changed";
614+
N_ = bottom[0]->shape(1);
615+
CHECK_EQ(bottom[1]->num_axes(), 2)
616+
<< "bottom[1] must have exactly 2 axes -- (#timesteps, #streams)";
617+
CHECK_EQ(T_, bottom[1]->shape(0));
618+
CHECK_EQ(N_, bottom[1]->shape(1));
619+
x_input_blob_->ReshapeLike(*bottom[0]);
620+
vector<int> cont_shape = bottom[1]->shape();
621+
cont_input_blob_->Reshape(cont_shape);
622+
vector<BlobShape> recur_input_shapes;
623+
RecurrentInputShapes(&recur_input_shapes);
624+
CHECK_EQ(recur_input_shapes.size(), recur_input_blobs_.size());
625+
for (int i = 0; i < recur_input_shapes.size(); ++i) {
626+
recur_input_blobs_[i]->Reshape(recur_input_shapes[i]);
627+
}
628+
unrolled_net_->Reshape();
629+
x_input_blob_->ShareData(*bottom[0]);
630+
x_input_blob_->ShareDiff(*bottom[0]);
631+
cont_input_blob_->ShareData(*bottom[1]);
632+
const int bottom_offset = 2;
633+
for (int i = bottom_offset, j = 0; i < bottom.size(); ++i, ++j) {
634+
CHECK(recur_input_blobs_[j]->shape() == bottom[i]->shape())
635+
<< "shape mismatch - recur_input_blobs_[" << j
636+
<< "]: " << recur_input_blobs_[j]->shape_string() << " vs. bottom["
637+
<< i << "]: " << bottom[i]->shape_string();
638+
recur_input_blobs_[j]->ShareData(*bottom[i]);
639+
}
640+
for (int i = 0; i < output_blobs_.size(); ++i) {
641+
top[i]->ReshapeLike(*output_blobs_[i]);
642+
top[i]->ShareData(*output_blobs_[i]);
643+
top[i]->ShareDiff(*output_blobs_[i]);
649644
}
645+
const int top_offset = output_blobs_.size();
646+
for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) {
647+
top[i]->ReshapeLike(*recur_output_blobs_[j]);
648+
}
649+
}
650650

651651

652-
template <typename Dtype>
653-
void RNNv2Layer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
654-
const vector<Blob<Dtype> *> &top) {
655-
// for bidirectional rnn, split the weights into two parts: forward
656-
// and reverse
657-
if (direction_ == "bidirectional") {
658-
const int blobs_size = this->blobs_.size();
659-
CHECK_EQ(unrolled_net_->params().size() % 2, 0);
660-
const int params_size = unrolled_net_->params().size() / 2;
661-
662-
for (int i = 0; i < blobs_size; ++i) {
663-
if (this->blobs_[i]->count() % 2 != 0)
664-
LOG(FATAL) << "The total number of the weight blobs[" << i
665-
<< "] cannot be divided by 2";
666-
// weight blob for forward
667-
void *f = unrolled_net_->params()[i]->mutable_cpu_data();
668-
// weight blob for backward
669-
void *b = unrolled_net_->params()[i + params_size]->mutable_cpu_data();
670-
std::memcpy(f, this->blobs_[i]->cpu_data(),
671-
sizeof(Dtype) * this->blobs_[i]->count() / 2);
672-
std::memcpy(b,
673-
this->blobs_[i]->cpu_data() + this->blobs_[i]->count() / 2,
674-
sizeof(Dtype) * this->blobs_[i]->count() / 2);
675-
}
652+
template <typename Dtype>
653+
void RNNv2Layer<Dtype>::Forward_cpu(const vector<Blob<Dtype> *> &bottom,
654+
const vector<Blob<Dtype> *> &top) {
655+
// for bidirectional rnn, split the weights into two parts: forward
656+
// and reverse
657+
if (direction_ == "bidirectional") {
658+
const int blobs_size = this->blobs_.size();
659+
CHECK_EQ(unrolled_net_->params().size() % 2, 0);
660+
const int params_size = unrolled_net_->params().size() / 2;
661+
662+
for (int i = 0; i < blobs_size; ++i) {
663+
if (this->blobs_[i]->count() % 2 != 0)
664+
LOG(FATAL) << "The total number of the weight blobs[" << i
665+
<< "] cannot be divided by 2";
666+
// weight blob for forward
667+
void *f = unrolled_net_->params()[i]->mutable_cpu_data();
668+
// weight blob for backward
669+
void *b = unrolled_net_->params()[i + params_size]->mutable_cpu_data();
670+
std::memcpy(f, this->blobs_[i]->cpu_data(),
671+
sizeof(Dtype) * this->blobs_[i]->count() / 2);
672+
std::memcpy(b,
673+
this->blobs_[i]->cpu_data() + this->blobs_[i]->count() / 2,
674+
sizeof(Dtype) * this->blobs_[i]->count() / 2);
676675
}
676+
}
677677

678-
DCHECK_EQ(recur_input_blobs_.size(), recur_output_blobs_.size());
678+
DCHECK_EQ(recur_input_blobs_.size(), recur_output_blobs_.size());
679679

680-
unrolled_net_->ForwardTo(unrolled_net_->layers().size() - 1);
680+
unrolled_net_->ForwardTo(unrolled_net_->layers().size() - 1);
681681

682-
const int top_offset = output_blobs_.size();
683-
for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) {
684-
top[i]->ShareData(*recur_output_blobs_[j]);
685-
}
682+
const int top_offset = output_blobs_.size();
683+
for (int i = top_offset, j = 0; i < top.size(); ++i, ++j) {
684+
top[i]->ShareData(*recur_output_blobs_[j]);
686685
}
686+
}
687687

688-
INSTANTIATE_CLASS(RNNv2Layer);
689-
REGISTER_LAYER_CLASS(RNNv2);
688+
INSTANTIATE_CLASS(RNNv2Layer);
689+
REGISTER_LAYER_CLASS(RNNv2);
690690
} // namespace caffe

0 commit comments

Comments
 (0)