Skip to content

Commit 18aa786

Browse files
committed
RNNv2 major update. Still be unstable
1 parent 55d5e9a commit 18aa786

File tree

2 files changed

+206
-152
lines changed

2 files changed

+206
-152
lines changed

include/caffe/layers/rnn_v2_layer.hpp

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ template <typename Dtype> class RNNv2Layer : public Layer<Dtype> {
2727
virtual inline int MinBottomBlobs() const {
2828
int min_bottoms = 2;
2929
vector<string> inputs;
30-
this->RecurrentInputBlobNames(&inputs);
30+
this->RecurrentBlobNamePrefix(&inputs);
3131
min_bottoms += inputs.size();
3232
return min_bottoms;
3333
}
3434
virtual inline int MaxBottomBlobs() const { return MinBottomBlobs() + 1; }
3535
virtual inline int ExactNumTopBlobs() const {
3636
int num_tops = 1;
3737
vector<string> outputs;
38-
this->RecurrentOutputBlobNames(&outputs);
38+
this->RecurrentBlobNamePrefix(&outputs);
3939
num_tops += outputs.size();
4040
return num_tops;
4141
}
@@ -45,18 +45,19 @@ template <typename Dtype> class RNNv2Layer : public Layer<Dtype> {
4545
* @brief Fills net_param with the recurrent network architecture. Subclasses
4646
* should define this -- see RNNLayer for examples.
4747
*/
48-
void FillUnrolledNet(NetParameter *net_param, string x_name, string cont_name,
49-
vector<string> recur_input_names,
48+
void FillUnrolledNet(NetParameter *net_param,
49+
const string x_name,
50+
const string cont_name,
5051
vector<string> output_names,
51-
vector<string> recur_output_names,
52-
const string &name_prefix);
52+
vector<string> recur_name_prefix,
53+
const string &layer_name_prefix);
5354

5455
/**
5556
* @brief Fills names with the names of the 0th timestep recurrent input
5657
* Blob&s. Subclasses should define this -- see RNNLayer and LSTMLayer
5758
* for examples.
5859
*/
59-
void RecurrentInputBlobNames(vector<string> *names) const;
60+
void RecurrentBlobNamePrefix(vector<string> *names) const;
6061

6162
/**
6263
* @brief Fills shapes with the shapes of the recurrent input Blob&s.
@@ -65,13 +66,6 @@ template <typename Dtype> class RNNv2Layer : public Layer<Dtype> {
6566
*/
6667
void RecurrentInputShapes(vector<BlobShape> *shapes) const;
6768

68-
/**
69-
* @brief Fills names with the names of the Tth timestep recurrent output
70-
* Blob&s. Subclasses should define this -- see RNNLayer and LSTMLayer
71-
* for examples.
72-
*/
73-
void RecurrentOutputBlobNames(vector<string> *names) const;
74-
7569
/**
7670
* @brief Fills names with the names of the output blobs, concatenated across
7771
* all timesteps. Should return a name for each top Blob.

0 commit comments

Comments
 (0)