Skip to content

Commit 6be273c

Browse files
committed
add seq mode lstm
1 parent 3636329 commit 6be273c

File tree

1 file changed

+45
-7
lines changed

1 file changed

+45
-7
lines changed

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 45 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -15,10 +15,14 @@ limitations under the License. */
1515
#include "paddle/fluid/operators/fusion_lstm_op.h"
1616
#include <string>
1717
#include "paddle/fluid/operators/math/blas.h"
18+
#include "paddle/fluid/operators/math/cpu_vec.h"
1819
#include "paddle/fluid/operators/math/detail/activation_functions.h"
1920
#include "paddle/fluid/operators/math/fc_compute.h"
2021
#include "paddle/fluid/operators/math/lstm_compute.h"
2122
#include "paddle/fluid/operators/math/sequence2batch.h"
23+
#include "paddle/fluid/platform/cpu_info.h"
24+
25+
DEFINE_bool(seq_mode, true, "Use sequence mode");
2226

2327
namespace paddle {
2428
namespace operators {
@@ -98,7 +102,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
98102
ctx->ShareLoD("X", "Hidden");
99103
ctx->ShareLoD("X", "Cell");
100104

101-
int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
105+
int xx_width;
106+
if (FLAGS_seq_mode) {
107+
xx_width = wx_dims[1];
108+
} else {
109+
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
110+
}
102111
ctx->SetOutputDim("XX", {x_dims[0], xx_width});
103112
ctx->ShareLoD("X", "XX");
104113
}
@@ -205,10 +214,34 @@ inline void ReorderInitState(const DeviceContext& ctx,
205214
row_shuffle(ctx, src, index_lod, dst, indexed_src);
206215
}
207216

208-
template <typename DeviceContext, typename T>
217+
template <typename T>
209218
class FuisonLSTMKernel : public framework::OpKernel<T> {
210219
public:
211-
void Compute(const framework::ExecutionContext& ctx) const override {
220+
void SeqCompute(const framework::ExecutionContext& ctx) const {
221+
using DeviceContext = paddle::platform::CPUDeviceContext;
222+
auto* x = ctx.Input<LoDTensor>("X");
223+
auto* wx = ctx.Input<Tensor>("WeightX");
224+
auto* wh = ctx.Input<Tensor>("WeightH");
225+
auto* bias = ctx.Input<Tensor>("Bias");
226+
227+
auto* xx = ctx.Output<LoDTensor>("XX");
228+
229+
auto x_dims = x->dims(); // T x M
230+
auto wh_dims = wh->dims(); // D x 4D
231+
const int M = x_dims[1]; // x frame size
232+
const int D4 = wh_dims[1];
233+
234+
const T* x_data = x->data<T>();
235+
const T* wx_data = wx->data<T>();
236+
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
237+
238+
auto blas = math::GetBlas<DeviceContext, T>(ctx);
239+
math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data,
240+
xx_data, bias->data<T>());
241+
}
242+
243+
void BatchCompute(const framework::ExecutionContext& ctx) const {
244+
using DeviceContext = platform::CPUDeviceContext;
212245
auto* x = ctx.Input<LoDTensor>("X");
213246
auto* wx = ctx.Input<Tensor>("WeightX");
214247
auto* wh = ctx.Input<Tensor>("WeightH");
@@ -339,6 +372,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
339372
// restore the output cell state in LoDTensor from the batch cell
340373
to_seq(dev_ctx, batch_cell, cell_out);
341374
}
375+
void Compute(const framework::ExecutionContext& ctx) const override {
376+
if (FLAGS_seq_mode) {
377+
SeqCompute(ctx);
378+
} else {
379+
BatchCompute(ctx);
380+
}
381+
}
342382
};
343383

344384
} // namespace operators
@@ -348,7 +388,5 @@ namespace ops = paddle::operators;
348388
REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker,
349389
paddle::framework::DefaultGradOpDescMaker<true>);
350390

351-
REGISTER_OP_CPU_KERNEL(
352-
fusion_lstm,
353-
ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, float>,
354-
ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, double>);
391+
REGISTER_OP_CPU_KERNEL(fusion_lstm, ops::FuisonLSTMKernel<float>,
392+
ops::FuisonLSTMKernel<double>);

0 commit comments

Comments
 (0)