Skip to content

Commit 9131a35

Browse files
committed
replace the lstm compute with jitkernel
test=develop
1 parent b55c247 commit 9131a35

File tree

5 files changed

+22
-145
lines changed

5 files changed

+22
-145
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ op_library(flatten_op DEPS reshape_op)
299299
op_library(sequence_pad_op DEPS sequence_padding)
300300
op_library(unstack_op DEPS stack_op)
301301
op_library(fake_quantize_op DEPS memory)
302-
op_library(fusion_lstm_op DEPS cpu_lstm_compute)
302+
op_library(fusion_lstm_op DEPS jit_kernel)
303303
if (WITH_GPU)
304304
op_library(conv_op DEPS vol2col depthwise_conv im2col)
305305
op_library(layer_norm_op DEPS cub)

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ limitations under the License. */
1515
#include "paddle/fluid/operators/fusion_lstm_op.h"
1616
#include <string>
1717
#include "paddle/fluid/operators/math/blas.h"
18-
#include "paddle/fluid/operators/math/cpu_lstm_compute.h"
1918
#include "paddle/fluid/operators/math/cpu_vec.h"
2019
#include "paddle/fluid/operators/math/fc_compute.h"
20+
#include "paddle/fluid/operators/math/jit_kernel.h"
2121
#include "paddle/fluid/operators/math/sequence2batch.h"
2222
#include "paddle/fluid/platform/cpu_info.h"
2323

@@ -309,11 +309,6 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
309309
act_gate(D, gates + D3, gates + D3); \
310310
GET_Ht(ct, gates, ht)
311311

312-
#define COMPUTE_CtHt(gates, ct_1, ct, ht) \
313-
act_gate(D3, gates + D, gates + D); \
314-
GET_Ct(ct_1, gates, ct); \
315-
GET_Ht(ct, gates, ht)
316-
317312
#define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht) \
318313
/* get fgated and igated*/ \
319314
blas.VMUL(D, wc_data, ct_1, checked_cell_data); \
@@ -403,22 +398,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
403398
}
404399
}
405400
} else {
406-
// TODO(TJ): unly workaround, clean me
407-
std::function<void(T*, const T*, T*, T*)> compute_ctht;
408-
if (platform::jit::MayIUse(platform::jit::avx) &&
409-
act_gate_str == "sigmoid" && act_cand_str == "tanh" &&
410-
act_cell_str == "tanh" && D == 8) {
411-
compute_ctht = math::lstm_compute_ctht<T>;
412-
} else {
413-
compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
414-
COMPUTE_CtHt(gates, ct_1, ct, ht);
415-
};
416-
}
401+
const auto& ker =
402+
math::jitkernel::KernelPool::Instance()
403+
.template Get<math::jitkernel::LSTMKernel<T>, int,
404+
const std::string&, const std::string&,
405+
const std::string&>(D, act_gate_str, act_cand_str,
406+
act_cell_str);
407+
417408
for (int i = 0; i < N; ++i) {
418409
PROCESS_H0C0
419410
for (int step = tstart; step < seq_len; ++step) {
420411
GEMM_WH_ADDON(1, prev_h_data, xx_data);
421-
compute_ctht(xx_data, prev_c_data, c_out_data, h_out_data);
412+
ker->ComputeCtHt(xx_data, prev_c_data, c_out_data, h_out_data);
422413
MOVE_ONE_STEP;
423414
}
424415
}
@@ -552,24 +543,20 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
552543
MOVE_ONE_STEP;
553544
}
554545
} else {
555-
// TODO(TJ): unly workaround, clean me
556-
std::function<void(T*, const T*, T*, T*)> compute_ctht;
557-
if (platform::jit::MayIUse(platform::jit::avx) &&
558-
act_gate_str == "sigmoid" && act_cand_str == "tanh" &&
559-
act_cell_str == "tanh" && D == 8) {
560-
compute_ctht = math::lstm_compute_ctht<T>;
561-
} else {
562-
compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
563-
COMPUTE_CtHt(gates, ct_1, ct, ht);
564-
};
565-
}
546+
const auto& ker =
547+
math::jitkernel::KernelPool::Instance()
548+
.template Get<math::jitkernel::LSTMKernel<T>, int,
549+
const std::string&, const std::string&,
550+
const std::string&>(D, act_gate_str, act_cand_str,
551+
act_cell_str);
552+
566553
for (int step = tstart; step < max_seq_len; ++step) {
567554
const int cur_bs = batch_starts[step + 1] - batch_starts[step];
568555
GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
569556
DEFINE_CUR;
570557
for (int i = 0; i < cur_bs; ++i) {
571-
compute_ctht(cur_in_data, cur_prev_c_data, cur_c_out_data,
572-
cur_h_out_data);
558+
ker->ComputeCtHt(cur_in_data, cur_prev_c_data, cur_c_out_data,
559+
cur_h_out_data);
573560
MOVE_ONE_BATCH;
574561
}
575562
MOVE_ONE_STEP;
@@ -595,7 +582,6 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
595582
}
596583

597584
#undef COMPUTE_CtHt_PEEPHOLE
598-
#undef COMPUTE_CtHt
599585
#undef GET_Ct_NOH0C0
600586
#undef COMPUTE_CtHt_NOH0C0
601587
#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0

paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ math_library(im2col)
4545
if (NOT WIN32) # windows do not support avx functions yet.
4646
math_library(gru_compute DEPS activation_functions math_function)
4747
math_library(lstm_compute DEPS activation_functions)
48-
# TODO(TJ): ugly workaround, clean me
49-
cc_library(cpu_lstm_compute SRCS cpu_lstm_compute.cc DEPS activation_functions cblas cpu_info)
5048
endif (NOT WIN32)
5149

5250
cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
@@ -76,7 +74,7 @@ if(WITH_GPU)
7674
endif()
7775
cc_test(concat_test SRCS concat_test.cc DEPS concat)
7876
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
79-
cc_library(jit_kernel_exp SRCS jit_kernel_exp.cc DEPS cpu_info cblas activation_functions)
80-
cc_library(jit_kernel_lstm SRCS jit_kernel_lstm.cc DEPS cpu_info cblas activation_functions)
81-
cc_library(jit_kernel SRCS jit_kernel.cc jit_kernel_blas.cc DEPS cpu_info cblas jit_kernel_exp jit_kernel_lstm)
77+
cc_library(jit_kernel
78+
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc
79+
DEPS cpu_info cblas activation_functions)
8280
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)

paddle/fluid/operators/math/cpu_lstm_compute.cc

Lines changed: 0 additions & 43 deletions
This file was deleted.

paddle/fluid/operators/math/cpu_lstm_compute.h

Lines changed: 0 additions & 64 deletions
This file was deleted.

0 commit comments

Comments
 (0)