@@ -215,46 +215,53 @@ This operator fuses X into LSTM; for more details, refer to the LSTM op.
215
215
template <typename T>
216
216
class FuisonLSTMKernel : public framework ::OpKernel<T> {
217
217
public:
218
// Resolve the gate/cell/candidate activation functors once per kernel call.
// Uses the AVX-accelerated vector activations when the CPU supports them,
// otherwise falls back to the ISA-agnostic implementation.
// NOTE: attribute names must match the op definition exactly (no stray
// spaces inside the literals), or ctx.Attr will fail to find them.
#define INIT_VEC_FUNC                                                        \
  std::function<void(const int, const T*, T*)> act_gate, act_cell, act_cand; \
  auto& act_gate_str = ctx.Attr<std::string>("gate_activation");             \
  auto& act_cell_str = ctx.Attr<std::string>("cell_activation");             \
  auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");        \
  if (platform::jit::MayIUse(platform::jit::avx)) {                          \
    math::VecActivations<T, platform::jit::avx> act_functor;                 \
    act_gate = act_functor(act_gate_str);                                    \
    act_cell = act_functor(act_cell_str);                                    \
    act_cand = act_functor(act_cand_str);                                    \
  } else {                                                                   \
    math::VecActivations<T, platform::jit::isa_any> act_functor;             \
    act_gate = act_functor(act_gate_str);                                    \
    act_cell = act_functor(act_cell_str);                                    \
    act_cand = act_functor(act_cand_str);                                    \
  }
234
+
235
// Fetch the tensors and attributes shared by SeqCompute and BatchCompute.
// Declares: x, h0, c0 (optional inputs), wx, wh, bias, the xx/hidden_out/
// cell_out outputs, and the is_reverse attribute.
// NOTE: input/output names must match the op definition exactly (no stray
// spaces inside the literals).
#define INIT_BASE_INPUT_OUTPUT                          \
  auto* x = ctx.Input<LoDTensor>("X");                  \
  auto* h0 = ctx.Input<Tensor>("H0");                   \
  auto* c0 = ctx.Input<Tensor>("C0");                   \
  auto* wx = ctx.Input<Tensor>("WeightX");              \
  auto* wh = ctx.Input<Tensor>("WeightH");              \
  auto* bias = ctx.Input<Tensor>("Bias");               \
  auto* xx = ctx.Output<LoDTensor>("XX");               \
  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");   \
  auto* cell_out = ctx.Output<LoDTensor>("Cell");       \
  bool is_reverse = ctx.Attr<bool>("is_reverse");
246
+
247
// Derive the common frame sizes from the input and hidden-weight dims.
// M: input frame size, D: hidden size; D2/D3/D4 are the 2D/3D/4D offsets
// into the 4-gate concatenated buffers. Requires x and wh to be in scope
// (i.e. INIT_BASE_INPUT_OUTPUT must precede this macro).
#define INIT_BASE_SIZES                    \
  auto x_dims = x->dims();    /* T x M */  \
  auto wh_dims = wh->dims();  /* D x 4D */ \
  const int M = x_dims[1];                 \
  const int D = wh_dims[0];                \
  const int D2 = D * 2;                    \
  const int D3 = D * 3;                    \
  const int D4 = wh_dims[1];
255
+
218
256
void SeqCompute (const framework::ExecutionContext& ctx) const {
219
257
using DeviceContext = paddle::platform::CPUDeviceContext;
220
- auto * x = ctx.Input <LoDTensor>(" X" );
221
- auto * h0 = ctx.Input <Tensor>(" H0" );
222
- auto * c0 = ctx.Input <Tensor>(" C0" );
223
- auto * wx = ctx.Input <Tensor>(" WeightX" );
224
- auto * wh = ctx.Input <Tensor>(" WeightH" );
225
- auto * bias = ctx.Input <Tensor>(" Bias" );
226
-
227
- auto * xx = ctx.Output <LoDTensor>(" XX" );
228
- auto * hidden_out = ctx.Output <LoDTensor>(" Hidden" );
229
- auto * cell_out = ctx.Output <LoDTensor>(" Cell" );
230
- bool is_reverse = ctx.Attr <bool >(" is_reverse" );
231
-
232
- std::function<void (const int , const T *, T *)> act_gate, act_cell, act_cand;
233
- auto & act_gate_str = ctx.Attr <std::string>(" gate_activation" );
234
- auto & act_cell_str = ctx.Attr <std::string>(" cell_activation" );
235
- auto & act_cand_str = ctx.Attr <std::string>(" candidate_activation" );
236
- if (platform::jit::MayIUse (platform::jit::avx)) {
237
- math::VecActivations<T, platform::jit::avx> act_functor;
238
- act_gate = act_functor (act_gate_str);
239
- act_cell = act_functor (act_cell_str);
240
- act_cand = act_functor (act_cand_str);
241
- } else {
242
- math::VecActivations<T, platform::jit::isa_any> act_functor;
243
- act_gate = act_functor (act_gate_str);
244
- act_cell = act_functor (act_cell_str);
245
- act_cand = act_functor (act_cand_str);
246
- }
258
+ INIT_BASE_INPUT_OUTPUT
259
+ INIT_BASE_SIZES
260
+ INIT_VEC_FUNC
247
261
248
262
auto x_lod = x->lod ();
249
- auto x_dims = x->dims (); // T x M
250
- auto wh_dims = wh->dims (); // D x 4D
251
263
const int total_T = x_dims[0 ];
252
264
const int N = x_lod[0 ].size () - 1 ; // batch size
253
- const int M = x_dims[1 ]; // x frame size
254
- const int D = wh_dims[0 ];
255
- const int D2 = D * 2 ;
256
- const int D3 = D * 3 ;
257
- const int D4 = wh_dims[1 ];
258
265
259
266
const T* x_data = x->data <T>();
260
267
const T* h0_data = h0 ? h0->data <T>() : NULL ;
@@ -343,52 +350,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
343
350
344
351
void BatchCompute (const framework::ExecutionContext& ctx) const {
345
352
using DeviceContext = platform::CPUDeviceContext;
346
- auto * x = ctx.Input <LoDTensor>(" X" );
347
- auto * wx = ctx.Input <Tensor>(" WeightX" );
348
- auto * wh = ctx.Input <Tensor>(" WeightH" );
349
- auto * bias = ctx.Input <Tensor>(" Bias" );
350
- auto * h0 = ctx.Input <Tensor>(" H0" );
351
- auto * c0 = ctx.Input <Tensor>(" C0" );
352
-
353
- auto * xx = ctx.Output <LoDTensor>(" XX" );
353
+ INIT_BASE_INPUT_OUTPUT
354
+ if (x->lod ()[0 ].size () == 2 ) { // batch size == 1
355
+ SeqCompute (ctx);
356
+ }
357
+ INIT_BASE_SIZES
358
+ INIT_VEC_FUNC
359
+
354
360
auto * reordered_h0 = ctx.Output <Tensor>(" ReorderedH0" );
355
361
auto * reordered_c0 = ctx.Output <Tensor>(" ReorderedC0" );
356
362
auto * batched_input = ctx.Output <LoDTensor>(" BatchedInput" );
357
363
auto * batched_c_out = ctx.Output <LoDTensor>(" BatchedCell" );
358
364
auto * batched_h_out = ctx.Output <LoDTensor>(" BatchedHidden" );
359
- auto * hidden_out = ctx.Output <LoDTensor>(" Hidden" );
360
- auto * cell_out = ctx.Output <LoDTensor>(" Cell" );
361
- bool is_reverse = ctx.Attr <bool >(" is_reverse" );
362
-
363
- std::function<void (const int , const T *, T *)> act_gate, act_cell, act_cand;
364
- auto & act_gate_str = ctx.Attr <std::string>(" gate_activation" );
365
- auto & act_cell_str = ctx.Attr <std::string>(" cell_activation" );
366
- auto & act_cand_str = ctx.Attr <std::string>(" candidate_activation" );
367
- if (platform::jit::MayIUse (platform::jit::avx)) {
368
- math::VecActivations<T, platform::jit::avx> act_functor;
369
- act_gate = act_functor (act_gate_str);
370
- act_cell = act_functor (act_cell_str);
371
- act_cand = act_functor (act_cand_str);
372
- } else {
373
- math::VecActivations<T, platform::jit::isa_any> act_functor;
374
- act_gate = act_functor (act_gate_str);
375
- act_cell = act_functor (act_cell_str);
376
- act_cand = act_functor (act_cand_str);
377
- }
378
-
379
- auto x_dims = x->dims (); // T x M
380
- auto wh_dims = wh->dims (); // D x 4D
381
-
382
- // auto x_lod = x->lod();
383
- // const int N = x_lod[0].size() - 1; // batch size
384
- // if (N == 1) {
385
- // SeqCompute(ctx);
386
- // }
387
- const int M = x_dims[1 ];
388
- const int D = wh_dims[0 ];
389
- const int D2 = D * 2 ;
390
- const int D3 = D * 3 ;
391
- const int D4 = wh_dims[1 ];
392
365
393
366
const T* x_data = x->data <T>();
394
367
const T* wx_data = wx->data <T>();
@@ -485,16 +458,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
485
458
// W_ch, W_ih, W_fh, W_oh
486
459
act_gate (D3, cur_in_data + D, cur_in_data + D);
487
460
act_cand (D, cur_in_data, cur_in_data);
488
-
489
461
// a = forget * prev_cell
490
462
blas.VMUL (D, cur_in_data + D2, cur_prev_c_data, cur_in_data + D2);
491
-
492
463
// b = input * tilde
493
464
blas.VMUL (D, cur_in_data, cur_in_data + D, cur_in_data + D);
494
-
495
465
// cell out= a+b
496
466
blas.VADD (D, cur_in_data + D, cur_in_data + D2, cur_c_out_data);
497
-
498
467
// hidden out= act_state(cellout) * outgate
499
468
act_cell (D, cur_c_out_data, cur_in_data + D2);
500
469
blas.VMUL (D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);
@@ -526,6 +495,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
526
495
BatchCompute (ctx);
527
496
}
528
497
}
498
+ #undef INIT_BASE_SIZES
499
+ #undef INIT_BASE_INPUT_OUTPUT
500
+ #undef INIT_VEC_FUNC
529
501
};
530
502
531
503
} // namespace operators
0 commit comments