Commit 3f1062d

Merge pull request #4929 from qingqing01/lstm
Forward implementation for LSTM operator.

2 parents: 154e1d0 + cf2608e
25 files changed: +2351 -6 lines
paddle/operators/CMakeLists.txt

Lines changed: 3 additions & 1 deletion

@@ -115,7 +115,8 @@ set(DEPS_OPS
     softmax_with_cross_entropy_op
     sum_op
     pool_op
-    pool_with_index_op)
+    pool_with_index_op
+    lstm_op)


 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
@@ -126,6 +127,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
 op_library(sum_op DEPS net_op)
 op_library(pool_op DEPS pooling)
 op_library(pool_with_index_op DEPS pooling)
+op_library(lstm_op DEPS sequence2batch lstm_compute)

 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
 foreach(src ${GENERAL_OPS})

paddle/operators/lstm_op.cc

Lines changed: 226 additions & 0 deletions

@@ -0,0 +1,226 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/lstm_op.h"

namespace paddle {
namespace operators {

class LSTMOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input(Input) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                   "Output(Hidden) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Cell"),
                   "Output(Cell) of LSTM should not be null.");

    auto x_dims = ctx->GetInputDim("Input");
    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(Input)'s rank must be 2.");

    if (ctx->HasInput("H0")) {
      PADDLE_ENFORCE(ctx->HasInput("C0"),
                     "Input(C0) of LSTM should not be null "
                     "when Input(H0) is provided.");
      auto h_dims = ctx->GetInputDim("H0");
      auto c_dims = ctx->GetInputDim("C0");
      PADDLE_ENFORCE(h_dims == c_dims,
                     "The dimensions of Input(H0) and Input(C0) "
                     "should be the same.");
    }

    int frame_size = x_dims[1] / 4;
    auto w_dims = ctx->GetInputDim("Weight");
    PADDLE_ENFORCE_EQ(w_dims.size(), 2,
                      "The rank of Input(Weight) should be 2.");
    PADDLE_ENFORCE_EQ(w_dims[0], frame_size,
                      "The first dimension of Input(Weight) "
                      "should be %d.",
                      frame_size);
    PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
                      "The second dimension of Input(Weight) "
                      "should be 4 * %d.",
                      frame_size);
    auto b_dims = ctx->GetInputDim("Bias");
    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
    PADDLE_ENFORCE_EQ(b_dims[0], 1,
                      "The first dimension of Input(Bias) should be 1.");
    if (ctx->Attrs().Get<bool>("usePeepholes")) {
      PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
                        "The second dimension of Input(Bias) should be "
                        "7 * %d if peephole connections are enabled.",
                        frame_size);
    } else {
      PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
                        "The second dimension of Input(Bias) should be "
                        "4 * %d if peephole connections are disabled.",
                        frame_size);
    }
    ctx->SetOutputDim("Hidden", {x_dims[0], frame_size});
    ctx->SetOutputDim("Cell", {x_dims[0], frame_size});
    ctx->SetOutputDim("BatchGate", x_dims);
    ctx->ShareLoD("Input", "Hidden");
    ctx->ShareLoD("Input", "Cell");
  }
};

class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LSTMOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Input",
             "(LoDTensor) the first input is a LoDTensor, which supports "
             "variable-length input sequences. The underlying tensor in "
             "this LoDTensor is a matrix with shape (T x 4D), where T is the "
             "total number of time steps in this mini-batch and D is the "
             "hidden size.");
    AddInput("H0",
             "(Tensor, optional) the initial hidden state is an optional "
             "input. This is a tensor with shape (N x D), where N is the "
             "batch size and D is the hidden size.");
    AddInput("C0",
             "(Tensor, optional) the initial cell state is an optional "
             "input. This is a tensor with shape (N x D), where N is the "
             "batch size. `H0` and `C0` can be NULL only at the same time.");
    AddInput("Weight",
             "(Tensor) the learnable hidden-hidden weights."
             " - The shape is (D x 4D), where D is the hidden size. "
             " - Weight = {W_ch, W_ih, W_fh, W_oh}");
    AddInput("Bias",
             "(Tensor) the learnable weights, which contain two parts: "
             "the input-hidden bias weights and the peephole connection "
             "weights when `usePeepholes` is True. "
             "1. `usePeepholes = False` "
             " - The shape is (1 x 4D). "
             " - Bias = {b_c, b_i, b_f, b_o}."
             "2. `usePeepholes = True` "
             " - The shape is (1 x 7D). "
             " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
    AddOutput("BatchGate",
              "(LoDTensor) This LoDTensor contains the input gate, forget gate "
              "and output gate after the nonlinear computation. It has the "
              "same shape as the reorganized input, which is also called the "
              "batch input. The LoD size is 2: the first LoD holds the batch "
              "offsets and the second LoD holds the indexes, which denote the "
              "position of each reorganized sequence in the raw input.")
        .AsIntermediate();
    AddOutput("Hidden",
              "(LoDTensor) the hidden state LoDTensor of the LSTM operator. "
              "The shape is (T x D) and the LoD is the same as that of "
              "`Input`.");
    AddOutput("Cell",
              "(LoDTensor) the cell state LoDTensor of the LSTM operator. "
              "The shape is (T x D) and the LoD is the same as that of "
              "`Input`.");
    AddAttr<bool>("usePeepholes",
                  "(bool, default: True) "
                  "whether to enable diagonal/peephole connections.")
        .SetDefault(true);
    AddAttr<bool>("isReverse",
                  "(bool, default: False) "
                  "whether to compute the reversed LSTM.")
        .SetDefault(false);
    AddAttr<std::string>(
        "gateActivation",
        "(string, default: sigmoid) "
        "The activation for the input gate, forget gate and output gate, "
        "`sigmoid` by default.")
        .SetDefault("sigmoid");
    AddAttr<std::string>("cellActivation",
                         "(string, default: tanh) "
                         "The activation for the cell output, `tanh` by "
                         "default.")
        .SetDefault("tanh");
    AddAttr<std::string>("candidateActivation",
                         "(string, default: tanh) "
                         "The activation for the candidate hidden state, "
                         "`tanh` by default.")
        .SetDefault("tanh");
    AddComment(R"DOC(Long Short-Term Memory (LSTM) Operator

The default implementation uses diagonal/peephole connections [1]; the formulas are
as follows

i_t = \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + W_{ic}c_{t-1} + b_i)

f_t = \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + W_{fc}c_{t-1} + b_f)

\tilde{c_t} = act_g(W_{cx}x_t + W_{ch}h_{t-1} + b_c)

o_t = \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + W_{oc}c_t + b_o)

c_t = f_t ⊙ c_{t-1} + i_t ⊙ \tilde{c_t}

h_t = o_t ⊙ act_h(c_t)

where the W terms denote weight matrices (e.g. \f$W_{ix}\f$ is the matrix
of weights from the input to the input gate), and \f$W_{ic}, W_{fc}, W_{oc}\f$
are diagonal weight matrices for the peephole connections. In our implementation,
we use vectors to represent these diagonal weight matrices. The b terms
denote bias vectors (\f$b_i\f$ is the input gate bias vector), \f$\sigma\f$
is a non-linear activation, such as the logistic sigmoid function, and
\f$i, f, o\f$ and \f$c\f$ are respectively the input gate, forget gate,
output gate and cell activation vectors, all of which are the same size as
the cell output activation vector \f$h\f$.

⊙ is the element-wise product of vectors. \f$act_g\f$ and \f$act_h\f$
are the cell input and cell output activation functions; `tanh` is usually
used for them. \f$\tilde{c_t}\f$ is also called the candidate hidden state,
which is computed from the current input and the previous hidden state.

Set `usePeepholes` to False to disable the peephole connections [2]. The formulas
are omitted here.

@note The \f$W_{ix}x_{t}, W_{fx}x_{t}, W_{cx}x_{t}, W_{ox}x_{t}\f$
operations on the input x_{t} are NOT included in this operator.
Users can choose to use a fully-connected operator before the LSTM operator.

[1] Hasim Sak, Andrew Senior, and Francoise Beaufays. Long short-term memory
recurrent neural network architectures for large scale acoustic modeling.
INTERSPEECH, 2014.

[2] S. Hochreiter and J. Schmidhuber. Long Short-Term Memory.
Neural Computation, 9(8):1735-1780, 1997.

)DOC");
  }
};

class LSTMGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
                   "Input(Hidden@GRAD) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")),
                   "Input(Cell@GRAD) should not be null");
    ctx->SetOutputDim(framework::GradVarName("Weight"),
                      ctx->GetInputDim("Weight"));
    ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(lstm, ops::LSTMOp, ops::LSTMOpMaker, lstm_grad, ops::LSTMGradOp);
REGISTER_OP_CPU_KERNEL(lstm, ops::LSTMKernel<paddle::platform::CPUPlace, float>,
                       ops::LSTMKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(lstm_grad,
                       ops::LSTMGradKernel<paddle::platform::CPUPlace, float>,
                       ops::LSTMGradKernel<paddle::platform::CPUPlace, double>);
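
As a quick check on the shape contract enforced by LSTMOp::InferShape above, here is a minimal standalone sketch. It is not part of this commit; the sizes T = 6 and D = 32 and all variable names are made up for illustration. It spells out the expected dimensions of Input, Weight, Bias, Hidden, Cell and BatchGate when usePeepholes is enabled.

#include <cstdio>

int main() {
  const int T = 6;   // total time steps summed over all sequences in the mini-batch
  const int D = 32;  // hidden size

  // Input is (T x 4D); InferShape recovers the hidden size from its width.
  const int input_cols = 4 * D;
  const int frame_size = input_cols / 4;  // == D

  // Weight is (D x 4D): the recurrent weights W_ch, W_ih, W_fh, W_oh side by side.
  const int weight_cols = 4 * frame_size;

  // Bias is (1 x 4D) without peepholes, or (1 x 7D) with the peephole
  // vectors W_ic, W_fc, W_oc appended after the four gate biases.
  const int bias_cols_peephole = 7 * frame_size;

  // Hidden and Cell come out as (T x D); BatchGate keeps the (T x 4D) input shape.
  std::printf("Input: %d x %d, Weight: %d x %d, Bias: 1 x %d, Hidden/Cell: %d x %d\n",
              T, input_cols, frame_size, weight_cols, bias_cols_peephole, T,
              frame_size);
  return 0;
}

In words: the operator infers the hidden size D from the width of Input (4D), expects the recurrent Weight to be D x 4D, and widens Bias from 4D to 7D only when the three peephole vectors are stored alongside the gate biases.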

paddle/operators/lstm_op.cu

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/lstm_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lstm, ops::LSTMKernel<paddle::platform::GPUPlace, float>,
                       ops::LSTMKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(lstm_grad,
                       ops::LSTMGradKernel<paddle::platform::GPUPlace, float>,
                       ops::LSTMGradKernel<paddle::platform::GPUPlace, double>);
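
The DOC block of LSTMOpMaker above gives the per-timestep recurrence. The sketch below only illustrates those equations with the default sigmoid/tanh activations and peephole connections enabled; it is not the operator's kernel (that is declared in paddle/operators/lstm_op.h, which this excerpt does not show). The LstmStep and Sigmoid names are hypothetical, and each gate_* vector is assumed to already hold the pre-activation value W_{*x}x_t + W_{*h}h_{t-1} + b_* for its gate.

#include <cmath>
#include <vector>

static float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// One LSTM step for a single sequence. w_ic/w_fc/w_oc are the diagonal
// peephole weights stored as vectors. c and h carry c_{t-1}/h_{t-1} on
// entry and c_t/h_t on return.
void LstmStep(const std::vector<float>& gate_i, const std::vector<float>& gate_f,
              const std::vector<float>& gate_c, const std::vector<float>& gate_o,
              const std::vector<float>& w_ic, const std::vector<float>& w_fc,
              const std::vector<float>& w_oc, std::vector<float>* c,
              std::vector<float>* h) {
  for (size_t k = 0; k < c->size(); ++k) {
    const float c_prev = (*c)[k];
    const float i_t = Sigmoid(gate_i[k] + w_ic[k] * c_prev);  // input gate
    const float f_t = Sigmoid(gate_f[k] + w_fc[k] * c_prev);  // forget gate
    const float c_tilde = std::tanh(gate_c[k]);               // candidate state, act_g
    const float c_t = f_t * c_prev + i_t * c_tilde;           // new cell state
    const float o_t = Sigmoid(gate_o[k] + w_oc[k] * c_t);     // output gate peeks at c_t
    (*c)[k] = c_t;
    (*h)[k] = o_t * std::tanh(c_t);                           // h_t = o_t ⊙ act_h(c_t)
  }
}

The peephole terms W_ic, W_fc, W_oc multiply the cell state element-wise, which is why the operator stores them as vectors inside Bias rather than as full matrices.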
