Skip to content

Commit 64fe9bc

Browse files
committed
Update LSTM comments and fix bug.
1 parent 34aac18 commit 64fe9bc

File tree

5 files changed

+16
-15
lines changed

5 files changed

+16
-15
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,8 @@ proto_library(framework_proto SRCS framework.proto)
2020

2121
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
2222
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info)
23-
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc)
23+
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
24+
device_context)
2425
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
2526
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
2627
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)

paddle/operators/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -127,7 +127,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
127127
op_library(sum_op DEPS net_op)
128128
op_library(pool_op DEPS pooling)
129129
op_library(pool_with_index_op DEPS pooling)
130-
op_library(lstm_op DEPS sequence2batch lstm_compute math_function)
130+
op_library(lstm_op DEPS sequence2batch lstm_compute)
131131

132132
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
133133
foreach(src ${GENERAL_OPS})

paddle/operators/lstm_op.cc

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -98,18 +98,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
9898
"batch size. `H0` and `C0` can be NULL but only at the same time");
9999
AddInput("Weight",
100100
"(Tensor) the learnable hidden-hidden weights."
101-
" - The shape is (D x 4*D), where D is the hidden size. "
102-
" - Weight = {W_ih, W_fh, W_ch, W_oh}");
101+
" - The shape is (D x 4D), where D is the hidden size. "
102+
" - Weight = {W_ch, W_ih, W_fh, W_oh}");
103103
AddInput("Bias",
104104
"(Tensor) the learnable weights, which contains two parts: "
105105
"input-hidden bias weight and peephole connections weight if "
106-
"seting `usePeepholes` True. "
106+
"setting `usePeepholes` True. "
107107
"1. `usePeepholes = False` "
108-
" - The shape is (1 x 4*D). "
109-
" - Bias = {b_i, b_f, b_c, b_o}."
108+
" - The shape is (1 x 4D). "
109+
" - Bias = {b_c, b_i, b_f, b_o}."
110110
"2. `usePeepholes = True` "
111-
" - The shape is (1 x 7*D). "
112-
" - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}.");
111+
" - The shape is (1 x 7D). "
112+
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
113113
AddOutput("BatchGate",
114114
"(LoDTensor) This LoDTensor contains input gate, forget gate "
115115
"and output gate after the nonlinear computation. This "
@@ -184,8 +184,8 @@ Set `usePeepholes` False to disable peephole connection [2]. The formula
184184
is omitted here.
185185
186186
@note These \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$
187-
operations on the input x_{t} were NOT included in this operator. The
188-
users can choose to use fully-connect operator before LSTM operator.
187+
operations on the input x_{t} were NOT included in this operator.
188+
Users can choose to use fully-connect operator before LSTM operator.
189189
190190
[1] Hasim Sak, Andrew Senior, and Francoise Beaufays. Long short-term memory
191191
recurrent neural network architectures for large scale acoustic modeling.

paddle/operators/lstm_op.h

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -76,14 +76,12 @@ class LSTMKernel : public framework::OpKernel<T> {
7676
lstm_value.checkOg = lstm_value.checkFg + frame_size;
7777
lstm_value.prevStateValue = nullptr;
7878

79-
framework::LoDTensor batch_out;
79+
framework::LoDTensor batch_out, batch_cell, batch_cell_pre_act;
8080
batch_out.mutable_data<T>(dims, ctx.GetPlace());
81-
framework::LoDTensor batch_cell;
8281
batch_cell.mutable_data<T>(dims, ctx.GetPlace());
83-
framework::LoDTensor batch_cell_pre_act;
8482
batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace());
8583

86-
auto& batch_starts = batch_gate->lod()[0];
84+
auto batch_starts = batch_gate->lod()[0];
8785
size_t num_batch = batch_starts.size() - 1;
8886
auto gate_act = ctx.Attr<std::string>("gateActivation");
8987
auto cell_act = ctx.Attr<std::string>("cellActivation");

paddle/operators/math/sequence2batch.cc

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,8 @@ class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
5151
template class CopyMatrixRowsFunctor<platform::CPUPlace, float>;
5252
template class CopyMatrixRowsFunctor<platform::CPUPlace, double>;
5353

54+
template class LoDTensor2BatchFunctor<platform::CPUPlace, float>;
55+
template class LoDTensor2BatchFunctor<platform::CPUPlace, double>;
5456
template class Batch2LoDTensorFunctor<platform::CPUPlace, float>;
5557
template class Batch2LoDTensorFunctor<platform::CPUPlace, double>;
5658

0 commit comments

Comments
 (0)