@@ -15,10 +15,14 @@ limitations under the License. */
15
15
#include " paddle/fluid/operators/fusion_lstm_op.h"
16
16
#include < string>
17
17
#include " paddle/fluid/operators/math/blas.h"
18
+ #include " paddle/fluid/operators/math/cpu_vec.h"
18
19
#include " paddle/fluid/operators/math/detail/activation_functions.h"
19
20
#include " paddle/fluid/operators/math/fc_compute.h"
20
21
#include " paddle/fluid/operators/math/lstm_compute.h"
21
22
#include " paddle/fluid/operators/math/sequence2batch.h"
23
+ #include " paddle/fluid/platform/cpu_info.h"
24
// Runtime switch between the two kernel paths below: when true (the default)
// Compute() runs SeqCompute, otherwise it falls back to BatchCompute.
// Declared with gflags, so it can be flipped via --seq_mode=false.
DEFINE_bool(seq_mode, true, "Use sequence mode");
namespace paddle {
24
28
namespace operators {
@@ -98,7 +102,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
98
102
ctx->ShareLoD (" X" , " Hidden" );
99
103
ctx->ShareLoD (" X" , " Cell" );
100
104
101
- int xx_width = x_dims[1 ] > wx_dims[1 ] ? wx_dims[1 ] : x_dims[1 ];
105
+ int xx_width;
106
+ if (FLAGS_seq_mode) {
107
+ xx_width = wx_dims[1 ];
108
+ } else {
109
+ xx_width = x_dims[1 ] > wx_dims[1 ] ? wx_dims[1 ] : x_dims[1 ];
110
+ }
102
111
ctx->SetOutputDim (" XX" , {x_dims[0 ], xx_width});
103
112
ctx->ShareLoD (" X" , " XX" );
104
113
}
@@ -205,10 +214,34 @@ inline void ReorderInitState(const DeviceContext& ctx,
205
214
row_shuffle (ctx, src, index_lod, dst, indexed_src);
206
215
}
207
216
208
- template <typename DeviceContext, typename T>
217
+ template <typename T>
209
218
class FuisonLSTMKernel : public framework ::OpKernel<T> {
210
219
public:
211
- void Compute (const framework::ExecutionContext& ctx) const override {
220
+ void SeqCompute (const framework::ExecutionContext& ctx) const {
221
+ using DeviceContext = paddle::platform::CPUDeviceContext;
222
+ auto * x = ctx.Input <LoDTensor>(" X" );
223
+ auto * wx = ctx.Input <Tensor>(" WeightX" );
224
+ auto * wh = ctx.Input <Tensor>(" WeightH" );
225
+ auto * bias = ctx.Input <Tensor>(" Bias" );
226
+
227
+ auto * xx = ctx.Output <LoDTensor>(" XX" );
228
+
229
+ auto x_dims = x->dims (); // T x M
230
+ auto wh_dims = wh->dims (); // D x 4D
231
+ const int M = x_dims[1 ]; // x frame size
232
+ const int D4 = wh_dims[1 ];
233
+
234
+ const T* x_data = x->data <T>();
235
+ const T* wx_data = wx->data <T>();
236
+ T* xx_data = xx->mutable_data <T>(ctx.GetPlace ());
237
+
238
+ auto blas = math::GetBlas<DeviceContext, T>(ctx);
239
+ math::FCCompute<DeviceContext, T>(blas, x_dims[0 ], D4, M, x_data, wx_data,
240
+ xx_data, bias->data <T>());
241
+ }
242
+
243
+ void BatchCompute (const framework::ExecutionContext& ctx) const {
244
+ using DeviceContext = platform::CPUDeviceContext;
212
245
auto * x = ctx.Input <LoDTensor>(" X" );
213
246
auto * wx = ctx.Input <Tensor>(" WeightX" );
214
247
auto * wh = ctx.Input <Tensor>(" WeightH" );
@@ -339,6 +372,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
339
372
// restore the output cell state in LoDTensor from the batch cell
340
373
to_seq (dev_ctx, batch_cell, cell_out);
341
374
}
375
+ void Compute (const framework::ExecutionContext& ctx) const override {
376
+ if (FLAGS_seq_mode) {
377
+ SeqCompute (ctx);
378
+ } else {
379
+ BatchCompute (ctx);
380
+ }
381
+ }
342
382
};
343
383
344
384
} // namespace operators
@@ -348,7 +388,5 @@ namespace ops = paddle::operators;
348
388
REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);

// CPU-only kernel registration. The kernel class is templated on the element
// type alone — the device context is fixed to CPUDeviceContext inside the
// kernel itself. ("FuisonLSTMKernel" spelling matches the class name as
// declared; the transposition is preserved intentionally.)
REGISTER_OP_CPU_KERNEL(fusion_lstm, ops::FuisonLSTMKernel<float>,
                       ops::FuisonLSTMKernel<double>);
0 commit comments