@@ -15,10 +15,14 @@ limitations under the License. */
15
15
#include " paddle/fluid/operators/fusion_lstm_op.h"
16
16
#include < string>
17
17
#include " paddle/fluid/operators/math/blas.h"
18
+ #include " paddle/fluid/operators/math/cpu_vec.h"
18
19
#include " paddle/fluid/operators/math/detail/activation_functions.h"
19
20
#include " paddle/fluid/operators/math/fc_compute.h"
20
21
#include " paddle/fluid/operators/math/lstm_compute.h"
21
22
#include " paddle/fluid/operators/math/sequence2batch.h"
23
+ #include " paddle/fluid/platform/cpu_info.h"
24
// Runtime switch between the two kernel implementations in this file:
// true  -> SeqCompute: process each sequence step-by-step in LoD order;
// false -> BatchCompute: reorder sequences into batches first.
// NOTE(review): a gflags flag (not an op attribute), so it applies
// process-wide to every fusion_lstm op instance.
DEFINE_bool(seq_mode, true, "Use sequence mode");
22
26
23
27
namespace paddle {
24
28
namespace operators {
@@ -98,7 +102,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
98
102
ctx->ShareLoD (" X" , " Hidden" );
99
103
ctx->ShareLoD (" X" , " Cell" );
100
104
101
- int xx_width = x_dims[1 ] > wx_dims[1 ] ? wx_dims[1 ] : x_dims[1 ];
105
+ int xx_width;
106
+ if (FLAGS_seq_mode) {
107
+ xx_width = wx_dims[1 ];
108
+ } else {
109
+ xx_width = x_dims[1 ] > wx_dims[1 ] ? wx_dims[1 ] : x_dims[1 ];
110
+ }
102
111
ctx->SetOutputDim (" XX" , {x_dims[0 ], xx_width});
103
112
ctx->ShareLoD (" X" , " XX" );
104
113
}
@@ -205,10 +214,138 @@ inline void ReorderInitState(const DeviceContext& ctx,
205
214
row_shuffle (ctx, src, index_lod, dst, indexed_src);
206
215
}
207
216
208
- template <typename DeviceContext, typename T>
217
+ template <typename T>
209
218
class FuisonLSTMKernel : public framework ::OpKernel<T> {
210
219
public:
211
- void Compute (const framework::ExecutionContext& ctx) const override {
220
  // Sequence-mode forward pass of the fused LSTM: walks every sequence of the
  // LoD batch step by step, computing all four gates for one time step with a
  // single small GEMM against WeightH.
  //
  // Inputs read : X (T x M, LoD), optional H0/C0 (per-sequence initial
  //               states, N x D each), WeightX (M x 4D), WeightH (D x 4D),
  //               Bias (4D; dereferenced unconditionally, so Bias is assumed
  //               to be a required input — TODO confirm against the op maker,
  //               which is outside this view).
  // Outputs     : XX (T x 4D, holds x*Wx+b then is overwritten in place with
  //               gate values), Hidden (T x D), Cell (T x D).
  void SeqCompute(const framework::ExecutionContext& ctx) const {
    using DeviceContext = paddle::platform::CPUDeviceContext;
    auto* x = ctx.Input<LoDTensor>("X");
    auto* h0 = ctx.Input<Tensor>("H0");
    auto* c0 = ctx.Input<Tensor>("C0");
    auto* wx = ctx.Input<Tensor>("WeightX");
    auto* wh = ctx.Input<Tensor>("WeightH");
    auto* bias = ctx.Input<Tensor>("Bias");

    auto* xx = ctx.Output<LoDTensor>("XX");
    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
    auto* cell_out = ctx.Output<LoDTensor>("Cell");
    bool is_reverse = ctx.Attr<bool>("is_reverse");

    // Pick vectorized activation implementations once, outside the time loop;
    // AVX path when the CPU supports it, generic path otherwise.
    std::function<void(const int, const T*, T*)> act_gate, act_cell, act_cand;
    auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
    auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
    auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
    if (platform::jit::MayIUse(platform::jit::avx)) {
      math::VecActivations<T, platform::jit::avx> act_functor;
      act_gate = act_functor(act_gate_str);
      act_cell = act_functor(act_cell_str);
      act_cand = act_functor(act_cand_str);
    } else {
      math::VecActivations<T, platform::jit::isa_any> act_functor;
      act_gate = act_functor(act_gate_str);
      act_cell = act_functor(act_cell_str);
      act_cand = act_functor(act_cand_str);
    }

    auto x_lod = x->lod();
    auto x_dims = x->dims();    // T x M
    auto wh_dims = wh->dims();  // D x 4D
    const int total_T = x_dims[0];
    const int N = x_lod[0].size() - 1;  // batch size (number of sequences)
    const int M = x_dims[1];            // x frame size
    const int D = wh_dims[0];           // hidden size
    const int D2 = D * 2;
    const int D3 = D * 3;
    const int D4 = wh_dims[1];          // 4 * D: one chunk per gate

    const T* x_data = x->data<T>();
    const T* h0_data = h0 ? h0->data<T>() : NULL;
    const T* c0_data = c0 ? c0->data<T>() : NULL;
    const T* wx_data = wx->data<T>();
    const T* wh_data = wh->data<T>();
    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
    T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());

    // XX = X * WeightX + Bias for ALL time steps in one big GEMM; the
    // per-step loop below only has to add prev_hidden * WeightH on top.
    auto blas = math::GetBlas<DeviceContext, T>(ctx);
    math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
                                      xx_data, bias->data<T>());
    // Per-step pointer strides; negated for reverse traversal.
    int xx_offset = D4;
    int gate_offset = D;
    if (is_reverse) {
      // Start at the last time step and walk backwards.
      const int offset = (total_T - 1) * D;
      xx_data = xx_data + offset * 4;  // == (total_T - 1) * D4
      hidden_out_data = hidden_out_data + offset;
      cell_out_data = cell_out_data + offset;
      xx_offset = -D4;
      gate_offset = -D;
    }

    // Advance all three cursors by one time step (direction-aware).
    auto move_step = [&]() {
      xx_data = xx_data + xx_offset;
      hidden_out_data = hidden_out_data + gate_offset;
      cell_out_data = cell_out_data + gate_offset;
    };

    for (int i = 0; i < N; ++i) {
      // In reverse mode the sequences themselves are visited back-to-front so
      // the cursors stay contiguous with the (reversed) global layout.
      int bid = is_reverse ? N - 1 - i : i;
      int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
      const T* prev_cell_data = NULL;
      const T* prev_hidden_data = NULL;
      int tstart = 0;
      if (h0_data) {
        // Initial states provided: every step (including step 0) goes through
        // the full recurrence below. Assumes C0 is present whenever H0 is —
        // c0_data is dereferenced here without its own check; TODO confirm
        // InferShape enforces that pairing.
        prev_hidden_data = h0_data + bid * D;
        prev_cell_data = c0_data + bid * D;
      } else {
        // No initial state: step 0 has no recurrent term, so the gates are
        // just the already-computed x*Wx+b values, activated in place.
        // Gate layout within one D4 row: [candidate | input | forget | output]
        // (W_ch, W_ih, W_fh, W_oh).
        act_gate(D3, xx_data + D, xx_data + D);
        act_cand(D, xx_data, xx_data);
        // cell = candidate * input_gate (forget term drops out: no prev cell)
        blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
        // hidden = act_cell(cell) * output_gate; act_cell result is staged in
        // the (now spent) forget-gate slot xx_data + D2.
        act_cell(D, cell_out_data, xx_data + D2);
        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);

        // This step's outputs become the recurrent inputs of step 1.
        prev_hidden_data = hidden_out_data;
        prev_cell_data = cell_out_data;
        tstart = 1;

        move_step();
      }
      for (int step = tstart; step < seq_len; ++step) {
        // xx += prev_hidden (1 x D) * WeightH (D x 4D); beta = 1 accumulates
        // onto the precomputed x*Wx+b row.
        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
                  prev_hidden_data, D, wh_data, D4, static_cast<T>(1), xx_data,
                  D4);

        // Activate gates in place: [candidate | input | forget | output].
        act_gate(D3, xx_data + D, xx_data + D);
        act_cand(D, xx_data, xx_data);

        // a = forget_gate * prev_cell (overwrites the forget-gate slot)
        blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2);

        // b = candidate * input_gate (overwrites the input-gate slot)
        blas.VMUL(D, xx_data, xx_data + D, xx_data + D);

        // cell = a + b
        blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);

        // hidden = act_cell(cell) * output_gate (staging in the spent slot)
        act_cell(D, cell_out_data, xx_data + D2);
        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);

        // Carry this step's outputs into the next iteration.
        prev_hidden_data = hidden_out_data;
        prev_cell_data = cell_out_data;

        move_step();
      }
    }
  }
346
+
347
+ void BatchCompute (const framework::ExecutionContext& ctx) const {
348
+ using DeviceContext = platform::CPUDeviceContext;
212
349
auto * x = ctx.Input <LoDTensor>(" X" );
213
350
auto * wx = ctx.Input <Tensor>(" WeightX" );
214
351
auto * wh = ctx.Input <Tensor>(" WeightH" );
@@ -339,6 +476,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
339
476
// restore the output cell state in LoDTensor from the batch cell
340
477
to_seq (dev_ctx, batch_cell, cell_out);
341
478
}
479
+ void Compute (const framework::ExecutionContext& ctx) const override {
480
+ if (FLAGS_seq_mode) {
481
+ SeqCompute (ctx);
482
+ } else {
483
+ BatchCompute (ctx);
484
+ }
485
+ }
342
486
};
343
487
344
488
} // namespace operators
@@ -348,7 +492,5 @@ namespace ops = paddle::operators;
348
492
// Register the fused LSTM operator (shape inference + proto maker). The
// DefaultGradOpDescMaker<true> wires up the default gradient op description.
REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);

// CPU-only kernels. The kernel class is templated on the element type alone;
// it fixes CPUDeviceContext internally and dispatches between the sequence
// and batch code paths itself.
// NOTE(review): "FuisonLSTMKernel" is a long-standing typo of "Fusion" —
// renaming would touch the class definition above, so it is kept as-is here.
REGISTER_OP_CPU_KERNEL(fusion_lstm, ops::FuisonLSTMKernel<float>,
                       ops::FuisonLSTMKernel<double>);
0 commit comments