Skip to content

Commit 4b28fab

Browse files
committed
enable more acts
1 parent 607c419 commit 4b28fab

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
230230
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
231231
auto* cell_out = ctx.Output<LoDTensor>("Cell");
232232

233+
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
234+
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
235+
auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
236+
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
237+
if (platform::jit::MayIUse(platform::jit::avx)) {
238+
math::VecActivations<T, platform::jit::avx> act_functor;
239+
act_gate = act_functor(act_gate_str);
240+
act_cell = act_functor(act_cell_str);
241+
act_cand = act_functor(act_cand_str);
242+
} else {
243+
math::VecActivations<T, platform::jit::isa_any> act_functor;
244+
act_gate = act_functor(act_gate_str);
245+
act_cell = act_functor(act_cell_str);
246+
act_cand = act_functor(act_cand_str);
247+
}
248+
233249
auto x_lod = x->lod();
234250
auto x_dims = x->dims(); // T x M
235251
auto wh_dims = wh->dims(); // D x 4D
@@ -263,15 +279,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
263279
prev_cell_data = c0_data + i * D;
264280
} else {
265281
// W_ch, W_ih, W_fh, W_oh
266-
// actgate
267-
math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
268-
// ch gate
269-
math::vec_tanh<T>(D, xx_data, xx_data);
282+
act_gate(D3, xx_data + D, xx_data + D);
283+
act_cand(D, xx_data, xx_data);
270284
// cell out= input*tilde
271285
blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
272286
// hidden out= act_state(cellout) * outgate
273-
// act state
274-
math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
287+
act_cell(D, cell_out_data, xx_data + D2);
275288
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
276289

277290
// prev
@@ -290,10 +303,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
290303
D4);
291304

292305
// W_ch, W_ih, W_fh, W_oh
293-
// actgate
294-
math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
295-
// ch gate
296-
math::vec_tanh<T>(D, xx_data, xx_data);
306+
act_gate(D3, xx_data + D, xx_data + D);
307+
act_cand(D, xx_data, xx_data);
297308

298309
// a = forget * prev_cell
299310
blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2);
@@ -305,8 +316,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
305316
blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);
306317

307318
// hidden out= act_state(cellout) * outgate
308-
// act state
309-
math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
319+
act_cell(D, cell_out_data, xx_data + D2);
310320
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
311321

312322
// prev

python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def fusion_lstm(
4545

4646
class TestLstmOp(OpTest):
4747
def set_argument(self):
48-
self.lod = [[2, 3, 2]]
48+
pass
4949

5050
def setUp(self):
5151
self.op_type = 'fusion_lstm'

0 commit comments

Comments
 (0)