@@ -230,6 +230,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
230
230
auto * hidden_out = ctx.Output <LoDTensor>(" Hidden" );
231
231
auto * cell_out = ctx.Output <LoDTensor>(" Cell" );
232
232
233
+ std::function<void (const int , const T *, T *)> act_gate, act_cell, act_cand;
234
+ auto & act_gate_str = ctx.Attr <std::string>(" gate_activation" );
235
+ auto & act_cell_str = ctx.Attr <std::string>(" cell_activation" );
236
+ auto & act_cand_str = ctx.Attr <std::string>(" candidate_activation" );
237
+ if (platform::jit::MayIUse (platform::jit::avx)) {
238
+ math::VecActivations<T, platform::jit::avx> act_functor;
239
+ act_gate = act_functor (act_gate_str);
240
+ act_cell = act_functor (act_cell_str);
241
+ act_cand = act_functor (act_cand_str);
242
+ } else {
243
+ math::VecActivations<T, platform::jit::isa_any> act_functor;
244
+ act_gate = act_functor (act_gate_str);
245
+ act_cell = act_functor (act_cell_str);
246
+ act_cand = act_functor (act_cand_str);
247
+ }
248
+
233
249
auto x_lod = x->lod ();
234
250
auto x_dims = x->dims (); // T x M
235
251
auto wh_dims = wh->dims (); // D x 4D
@@ -263,15 +279,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
263
279
prev_cell_data = c0_data + i * D;
264
280
} else {
265
281
// W_ch, W_ih, W_fh, W_oh
266
- // actgate
267
- math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
268
- // ch gate
269
- math::vec_tanh<T>(D, xx_data, xx_data);
282
+ act_gate (D3, xx_data + D, xx_data + D);
283
+ act_cand (D, xx_data, xx_data);
270
284
// cell out = input gate * candidate (c tilde)
271
285
blas.VMUL (D, xx_data, xx_data + D, cell_out_data);
272
286
// hidden out = act_cell(cell out) * output gate
273
- // act state
274
- math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
287
+ act_cell (D, cell_out_data, xx_data + D2);
275
288
blas.VMUL (D, xx_data + D2, xx_data + D3, hidden_out_data);
276
289
277
290
// prev
@@ -290,10 +303,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
290
303
D4);
291
304
292
305
// W_ch, W_ih, W_fh, W_oh
293
- // actgate
294
- math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
295
- // ch gate
296
- math::vec_tanh<T>(D, xx_data, xx_data);
306
+ act_gate (D3, xx_data + D, xx_data + D);
307
+ act_cand (D, xx_data, xx_data);
297
308
298
309
// a = forget * prev_cell
299
310
blas.VMUL (D, xx_data + D2, prev_cell_data, xx_data + D2);
@@ -305,8 +316,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
305
316
blas.VADD (D, xx_data + D, xx_data + D2, cell_out_data);
306
317
307
318
// hidden out = act_cell(cell out) * output gate
308
- // act state
309
- math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
319
+ act_cell (D, cell_out_data, xx_data + D2);
310
320
blas.VMUL (D, xx_data + D2, xx_data + D3, hidden_out_data);
311
321
312
322
// prev
0 commit comments