Skip to content

Commit a79a77e

Browse files
committed
refine and clean code
1 parent c459fb5 commit a79a77e

File tree

1 file changed

+51
-79
lines changed

1 file changed

+51
-79
lines changed

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 51 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -215,46 +215,53 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
215215
template <typename T>
216216
class FuisonLSTMKernel : public framework::OpKernel<T> {
217217
public:
218+
#define INIT_VEC_FUNC \
219+
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand; \
220+
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
221+
auto& act_cell_str = ctx.Attr<std::string>("cell_activation"); \
222+
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation"); \
223+
if (platform::jit::MayIUse(platform::jit::avx)) { \
224+
math::VecActivations<T, platform::jit::avx> act_functor; \
225+
act_gate = act_functor(act_gate_str); \
226+
act_cell = act_functor(act_cell_str); \
227+
act_cand = act_functor(act_cand_str); \
228+
} else { \
229+
math::VecActivations<T, platform::jit::isa_any> act_functor; \
230+
act_gate = act_functor(act_gate_str); \
231+
act_cell = act_functor(act_cell_str); \
232+
act_cand = act_functor(act_cand_str); \
233+
}
234+
235+
#define INIT_BASE_INPUT_OUTPUT \
236+
auto* x = ctx.Input<LoDTensor>("X"); \
237+
auto* h0 = ctx.Input<Tensor>("H0"); \
238+
auto* c0 = ctx.Input<Tensor>("C0"); \
239+
auto* wx = ctx.Input<Tensor>("WeightX"); \
240+
auto* wh = ctx.Input<Tensor>("WeightH"); \
241+
auto* bias = ctx.Input<Tensor>("Bias"); \
242+
auto* xx = ctx.Output<LoDTensor>("XX"); \
243+
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
244+
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
245+
bool is_reverse = ctx.Attr<bool>("is_reverse");
246+
247+
#define INIT_BASE_SIZES \
248+
auto x_dims = x->dims(); /* T x M*/ \
249+
auto wh_dims = wh->dims(); /* D x 4D*/ \
250+
const int M = x_dims[1]; \
251+
const int D = wh_dims[0]; \
252+
const int D2 = D * 2; \
253+
const int D3 = D * 3; \
254+
const int D4 = wh_dims[1];
255+
218256
void SeqCompute(const framework::ExecutionContext& ctx) const {
219257
using DeviceContext = paddle::platform::CPUDeviceContext;
220-
auto* x = ctx.Input<LoDTensor>("X");
221-
auto* h0 = ctx.Input<Tensor>("H0");
222-
auto* c0 = ctx.Input<Tensor>("C0");
223-
auto* wx = ctx.Input<Tensor>("WeightX");
224-
auto* wh = ctx.Input<Tensor>("WeightH");
225-
auto* bias = ctx.Input<Tensor>("Bias");
226-
227-
auto* xx = ctx.Output<LoDTensor>("XX");
228-
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
229-
auto* cell_out = ctx.Output<LoDTensor>("Cell");
230-
bool is_reverse = ctx.Attr<bool>("is_reverse");
231-
232-
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
233-
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
234-
auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
235-
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
236-
if (platform::jit::MayIUse(platform::jit::avx)) {
237-
math::VecActivations<T, platform::jit::avx> act_functor;
238-
act_gate = act_functor(act_gate_str);
239-
act_cell = act_functor(act_cell_str);
240-
act_cand = act_functor(act_cand_str);
241-
} else {
242-
math::VecActivations<T, platform::jit::isa_any> act_functor;
243-
act_gate = act_functor(act_gate_str);
244-
act_cell = act_functor(act_cell_str);
245-
act_cand = act_functor(act_cand_str);
246-
}
258+
INIT_BASE_INPUT_OUTPUT
259+
INIT_BASE_SIZES
260+
INIT_VEC_FUNC
247261

248262
auto x_lod = x->lod();
249-
auto x_dims = x->dims(); // T x M
250-
auto wh_dims = wh->dims(); // D x 4D
251263
const int total_T = x_dims[0];
252264
const int N = x_lod[0].size() - 1; // batch size
253-
const int M = x_dims[1]; // x frame size
254-
const int D = wh_dims[0];
255-
const int D2 = D * 2;
256-
const int D3 = D * 3;
257-
const int D4 = wh_dims[1];
258265

259266
const T* x_data = x->data<T>();
260267
const T* h0_data = h0 ? h0->data<T>() : NULL;
@@ -343,52 +350,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
343350

344351
void BatchCompute(const framework::ExecutionContext& ctx) const {
345352
using DeviceContext = platform::CPUDeviceContext;
346-
auto* x = ctx.Input<LoDTensor>("X");
347-
auto* wx = ctx.Input<Tensor>("WeightX");
348-
auto* wh = ctx.Input<Tensor>("WeightH");
349-
auto* bias = ctx.Input<Tensor>("Bias");
350-
auto* h0 = ctx.Input<Tensor>("H0");
351-
auto* c0 = ctx.Input<Tensor>("C0");
352-
353-
auto* xx = ctx.Output<LoDTensor>("XX");
353+
INIT_BASE_INPUT_OUTPUT
354+
if (x->lod()[0].size() == 2) { // batch size == 1
355+
SeqCompute(ctx);
356+
}
357+
INIT_BASE_SIZES
358+
INIT_VEC_FUNC
359+
354360
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
355361
auto* reordered_c0 = ctx.Output<Tensor>("ReorderedC0");
356362
auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
357363
auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell");
358364
auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden");
359-
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
360-
auto* cell_out = ctx.Output<LoDTensor>("Cell");
361-
bool is_reverse = ctx.Attr<bool>("is_reverse");
362-
363-
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
364-
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
365-
auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
366-
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
367-
if (platform::jit::MayIUse(platform::jit::avx)) {
368-
math::VecActivations<T, platform::jit::avx> act_functor;
369-
act_gate = act_functor(act_gate_str);
370-
act_cell = act_functor(act_cell_str);
371-
act_cand = act_functor(act_cand_str);
372-
} else {
373-
math::VecActivations<T, platform::jit::isa_any> act_functor;
374-
act_gate = act_functor(act_gate_str);
375-
act_cell = act_functor(act_cell_str);
376-
act_cand = act_functor(act_cand_str);
377-
}
378-
379-
auto x_dims = x->dims(); // T x M
380-
auto wh_dims = wh->dims(); // D x 4D
381-
382-
// auto x_lod = x->lod();
383-
// const int N = x_lod[0].size() - 1; // batch size
384-
// if (N == 1) {
385-
// SeqCompute(ctx);
386-
// }
387-
const int M = x_dims[1];
388-
const int D = wh_dims[0];
389-
const int D2 = D * 2;
390-
const int D3 = D * 3;
391-
const int D4 = wh_dims[1];
392365

393366
const T* x_data = x->data<T>();
394367
const T* wx_data = wx->data<T>();
@@ -485,16 +458,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
485458
// W_ch, W_ih, W_fh, W_oh
486459
act_gate(D3, cur_in_data + D, cur_in_data + D);
487460
act_cand(D, cur_in_data, cur_in_data);
488-
489461
// a = forget * prev_cell
490462
blas.VMUL(D, cur_in_data + D2, cur_prev_c_data, cur_in_data + D2);
491-
492463
// b = input * tilde
493464
blas.VMUL(D, cur_in_data, cur_in_data + D, cur_in_data + D);
494-
495465
// cell out= a+b
496466
blas.VADD(D, cur_in_data + D, cur_in_data + D2, cur_c_out_data);
497-
498467
// hidden out= act_state(cellout) * outgate
499468
act_cell(D, cur_c_out_data, cur_in_data + D2);
500469
blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);
@@ -526,6 +495,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
526495
BatchCompute(ctx);
527496
}
528497
}
498+
#undef INIT_BASE_SIZES
499+
#undef INIT_BASE_INPUT_OUTPUT
500+
#undef INIT_VEC_FUNC
529501
};
530502

531503
} // namespace operators

0 commit comments

Comments
 (0)