@@ -15,9 +15,9 @@ limitations under the License. */
15
15
#include " paddle/fluid/operators/fusion_lstm_op.h"
16
16
#include < string>
17
17
#include " paddle/fluid/operators/math/blas.h"
18
- #include " paddle/fluid/operators/math/cpu_lstm_compute.h"
19
18
#include " paddle/fluid/operators/math/cpu_vec.h"
20
19
#include " paddle/fluid/operators/math/fc_compute.h"
20
+ #include " paddle/fluid/operators/math/jit_kernel.h"
21
21
#include " paddle/fluid/operators/math/sequence2batch.h"
22
22
#include " paddle/fluid/platform/cpu_info.h"
23
23
@@ -309,11 +309,6 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
309
309
act_gate (D, gates + D3, gates + D3); \
310
310
GET_Ht (ct, gates, ht)
311
311
312
- #define COMPUTE_CtHt (gates, ct_1, ct, ht ) \
313
- act_gate (D3, gates + D, gates + D); \
314
- GET_Ct (ct_1, gates, ct); \
315
- GET_Ht (ct, gates, ht)
316
-
317
312
#define COMPUTE_CtHt_PEEPHOLE (gates, ct_1, ct, ht ) \
318
313
/* get fgated and igated*/ \
319
314
blas.VMUL(D, wc_data, ct_1, checked_cell_data); \
@@ -403,22 +398,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
403
398
}
404
399
}
405
400
} else {
406
- // TODO(TJ): unly workaround, clean me
407
- std::function<void (T*, const T*, T*, T*)> compute_ctht;
408
- if (platform::jit::MayIUse (platform::jit::avx) &&
409
- act_gate_str == " sigmoid" && act_cand_str == " tanh" &&
410
- act_cell_str == " tanh" && D == 8 ) {
411
- compute_ctht = math::lstm_compute_ctht<T>;
412
- } else {
413
- compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
414
- COMPUTE_CtHt (gates, ct_1, ct, ht);
415
- };
416
- }
401
+ const auto & ker =
402
+ math::jitkernel::KernelPool::Instance ()
403
+ .template Get <math::jitkernel::LSTMKernel<T>, int ,
404
+ const std::string&, const std::string&,
405
+ const std::string&>(D, act_gate_str, act_cand_str,
406
+ act_cell_str);
407
+
417
408
for (int i = 0 ; i < N; ++i) {
418
409
PROCESS_H0C0
419
410
for (int step = tstart; step < seq_len; ++step) {
420
411
GEMM_WH_ADDON (1 , prev_h_data, xx_data);
421
- compute_ctht (xx_data, prev_c_data, c_out_data, h_out_data);
412
+ ker-> ComputeCtHt (xx_data, prev_c_data, c_out_data, h_out_data);
422
413
MOVE_ONE_STEP;
423
414
}
424
415
}
@@ -552,24 +543,20 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
552
543
MOVE_ONE_STEP;
553
544
}
554
545
} else {
555
- // TODO(TJ): unly workaround, clean me
556
- std::function<void (T*, const T*, T*, T*)> compute_ctht;
557
- if (platform::jit::MayIUse (platform::jit::avx) &&
558
- act_gate_str == " sigmoid" && act_cand_str == " tanh" &&
559
- act_cell_str == " tanh" && D == 8 ) {
560
- compute_ctht = math::lstm_compute_ctht<T>;
561
- } else {
562
- compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
563
- COMPUTE_CtHt (gates, ct_1, ct, ht);
564
- };
565
- }
546
+ const auto & ker =
547
+ math::jitkernel::KernelPool::Instance ()
548
+ .template Get <math::jitkernel::LSTMKernel<T>, int ,
549
+ const std::string&, const std::string&,
550
+ const std::string&>(D, act_gate_str, act_cand_str,
551
+ act_cell_str);
552
+
566
553
for (int step = tstart; step < max_seq_len; ++step) {
567
554
const int cur_bs = batch_starts[step + 1 ] - batch_starts[step];
568
555
GEMM_WH_ADDON (cur_bs, prev_h_data, batched_input_data);
569
556
DEFINE_CUR;
570
557
for (int i = 0 ; i < cur_bs; ++i) {
571
- compute_ctht (cur_in_data, cur_prev_c_data, cur_c_out_data,
572
- cur_h_out_data);
558
+ ker-> ComputeCtHt (cur_in_data, cur_prev_c_data, cur_c_out_data,
559
+ cur_h_out_data);
573
560
MOVE_ONE_BATCH;
574
561
}
575
562
MOVE_ONE_STEP;
@@ -595,7 +582,6 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
595
582
}
596
583
597
584
#undef COMPUTE_CtHt_PEEPHOLE
598
- #undef COMPUTE_CtHt
599
585
#undef GET_Ct_NOH0C0
600
586
#undef COMPUTE_CtHt_NOH0C0
601
587
#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
0 commit comments