Skip to content

Commit 7ef2699

Browse files
committed
init peephole runtime kernel
1 parent 3ee8f2c commit 7ef2699

File tree

4 files changed

+104
-34
lines changed

4 files changed

+104
-34
lines changed

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -400,10 +400,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
400400
} else {
401401
const auto& ker =
402402
math::jitkernel::KernelPool::Instance()
403-
.template Get<math::jitkernel::LSTMKernel<T>, int,
404-
const std::string&, const std::string&,
405-
const std::string&>(D, act_gate_str, act_cand_str,
406-
act_cell_str);
403+
.template Get<math::jitkernel::LSTMKernel<T>, const std::string&,
404+
const std::string&, const std::string&>(
405+
act_gate_str, act_cand_str, act_cell_str, D, false);
407406

408407
for (int i = 0; i < N; ++i) {
409408
PROCESS_H0C0
@@ -545,10 +544,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
545544
} else {
546545
const auto& ker =
547546
math::jitkernel::KernelPool::Instance()
548-
.template Get<math::jitkernel::LSTMKernel<T>, int,
549-
const std::string&, const std::string&,
550-
const std::string&>(D, act_gate_str, act_cand_str,
551-
act_cell_str);
547+
.template Get<math::jitkernel::LSTMKernel<T>, const std::string&,
548+
const std::string&, const std::string&>(
549+
act_gate_str, act_cand_str, act_cell_str, D, false);
552550

553551
for (int step = tstart; step < max_seq_len; ++step) {
554552
const int cur_bs = batch_starts[step + 1] - batch_starts[step];

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ class VTanhKernel : public VActKernel<T> {
125125
template <typename T>
126126
class LSTMKernel : public Kernel {
127127
public:
128-
virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht) const = 0;
128+
virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht,
129+
T *checked = nullptr) const = 0;
129130
};
130131

131132
} // namespace jitkernel

paddle/fluid/operators/math/jit_kernel_lstm.cc

Lines changed: 84 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ __m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
8686
template <typename T, jit::cpu_isa_t isa, jit_block>
8787
class LSTMKernelImpl : public LSTMKernel<T> {
8888
public:
89-
explicit LSTMKernelImpl(int d, const std::string& act_gate,
89+
explicit LSTMKernelImpl(const std::string& act_gate,
9090
const std::string& act_cand,
91-
const std::string& act_cell)
91+
const std::string& act_cell, int d)
9292
: LSTMKernel<T>() {
9393
d_ = d;
9494
d2_ = d * 2;
@@ -134,7 +134,8 @@ class LSTMKernelImpl : public LSTMKernel<T> {
134134
#endif
135135
}
136136

137-
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht) const override {
137+
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht,
138+
T* checked) const override {
138139
// gates: W_ch, W_ih, W_fh, W_oh
139140
act_gate_3d_->Compute(gates + d_, gates + d_);
140141

@@ -162,7 +163,8 @@ class LSTMKernelImpl : public LSTMKernel<T> {
162163
#define INTRI8_FLOAT(isa) \
163164
template <> \
164165
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
165-
float* gates, const float* ct_1, float* ct, float* ht) const { \
166+
float* gates, const float* ct_1, float* ct, float* ht, float* checked) \
167+
const { \
166168
/* gates: W_ch, W_ih, W_fh, W_oh */ \
167169
__m256 c, i, f, o; \
168170
c = _mm256_loadu_ps(gates); \
@@ -192,21 +194,86 @@ INTRI8_FLOAT(jit::avx2);
192194
INTRI8_FLOAT(jit::avx512f);
193195
#endif
194196

195-
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
196-
template <> \
197-
std::shared_ptr<const ker_class<ker_dtype>> \
198-
KernelPool::Get<ker_class<ker_dtype>, int, const std::string&, \
199-
const std::string&, const std::string&>( \
200-
int d, const std::string& act_gate, const std::string& act_cand, \
201-
const std::string& act_cell)
197+
/* Peephole JitKernel */
198+
template <typename T, jit::cpu_isa_t isa, jit_block>
199+
class PeepholeKernelImpl : public LSTMKernel<T> {
200+
public:
201+
explicit PeepholeKernelImpl(const std::string& act_gate,
202+
const std::string& act_cand,
203+
const std::string& act_cell, int d)
204+
: LSTMKernel<T>() {
205+
d_ = d;
206+
d2_ = d * 2;
207+
d3_ = d * 3;
208+
auto GetActKernel = [&](const std::string& type,
209+
int n) -> std::shared_ptr<const VActKernel<T>> {
210+
if (type == "sigmoid") {
211+
return std::dynamic_pointer_cast<const VActKernel<T>>(
212+
KernelPool::Instance().template Get<VSigmoidKernel<T>>(n));
213+
} else if (type == "relu") {
214+
return std::dynamic_pointer_cast<const VActKernel<T>>(
215+
KernelPool::Instance().template Get<VReluKernel<T>>(n));
216+
} else if (type == "tanh") {
217+
return std::dynamic_pointer_cast<const VActKernel<T>>(
218+
KernelPool::Instance().template Get<VTanhKernel<T>>(n));
219+
} else if (type == "identity" || type == "") {
220+
return std::dynamic_pointer_cast<const VActKernel<T>>(
221+
KernelPool::Instance().template Get<VIdentityKernel<T>>(n));
222+
}
223+
PADDLE_THROW("Not support type: %s", type);
224+
};
225+
act_gate_3d_ = GetActKernel(act_gate, d * 3);
226+
act_cand_d_ = GetActKernel(act_cand, d);
227+
act_cell_d_ = GetActKernel(act_cell, d);
228+
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
229+
vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
230+
}
231+
232+
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht,
233+
T* checked) const override {
234+
// gates: W_ch, W_ih, W_fh, W_oh
235+
act_gate_3d_->Compute(gates + d_, gates + d_);
236+
237+
/* C_t = C_t-1 * fgated + cand_gated * igated */
238+
act_cand_d_->Compute(gates, gates);
239+
vmul_d_->Compute(gates, gates + d_, gates + d_);
240+
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_);
241+
vadd_d_->Compute(gates + d_, gates + d2_, ct);
242+
243+
/* H_t = act_cell(C_t) * ogated */
244+
act_cell_d_->Compute(ct, gates + d2_);
245+
vmul_d_->Compute(gates + d2_, gates + d3_, ht);
246+
}
202247

203-
#define JITKERNEL_KEY_LSTM(ker_key, dtype_key) \
204-
#ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell
248+
private:
249+
int d_, d2_, d3_;
250+
std::shared_ptr<const VActKernel<T>> act_gate_3d_, act_cand_d_, act_cell_d_;
251+
std::shared_ptr<const VMulKernel<T>> vmul_d_;
252+
std::shared_ptr<const VAddKernel<T>> vadd_d_;
253+
};
254+
255+
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
256+
template <> \
257+
std::shared_ptr<const LSTMKernel<ker_dtype>> \
258+
KernelPool::Get<LSTMKernel<ker_dtype>, const std::string&, \
259+
const std::string&, const std::string&, int, bool>( \
260+
const std::string& act_gate, const std::string& act_cand, \
261+
const std::string& act_cell, int d, bool use_peephole)
205262

206-
#define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k) \
207-
p = std::dynamic_pointer_cast<ker<dtype>>( \
208-
std::make_shared<ker##Impl<dtype, isa, k>>(d, act_gate, act_cand, \
209-
act_cell))
263+
#define JITKERNEL_KEY_LSTM(ker_key, dtype_key) \
264+
#ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell + \
265+
(use_peephole ? "p" : "n")
266+
267+
#define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k) \
268+
if (use_peephole) { \
269+
p = std::dynamic_pointer_cast<ker<dtype>>( \
270+
std::make_shared<PeepholeKernelImpl<dtype, isa, k>>( \
271+
act_gate, act_cand, act_cell, d)); \
272+
} else { \
273+
p = std::dynamic_pointer_cast<ker<dtype>>( \
274+
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_cand, \
275+
act_cell, d)); \
276+
}
210277

211278
REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
212279
JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL);
@@ -215,7 +282,6 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
215282
#undef JITKERNEL_DECLARE_LSTM
216283
#undef JITKERNEL_KEY_LSTM
217284
#undef JITKERNEL_NEW_LSTM_IMPL
218-
219285
} // namespace jitkernel
220286
} // namespace math
221287
} // namespace operators

paddle/fluid/operators/math/jit_kernel_test.cc

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -390,9 +390,9 @@ TEST(JitKernel, lstm) {
390390
std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
391391
const auto& ker =
392392
jit::KernelPool::Instance()
393-
.template Get<jit::LSTMKernel<float>, int, const std::string&,
393+
.template Get<jit::LSTMKernel<float>, const std::string&,
394394
const std::string&, const std::string&>(
395-
d, act_gate, act_cand, act_cell);
395+
act_gate, act_cand, act_cell, d, false);
396396
// below kernels are used to compute refer
397397
const auto& vsigmoid_3d =
398398
jit::KernelPool::Instance().template Get<jit::VSigmoidKernel<float>>(
@@ -717,15 +717,20 @@ TEST(JitKernel, pool) {
717717
std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
718718
const auto& plstm1 =
719719
jit::KernelPool::Instance()
720-
.template Get<jit::LSTMKernel<float>, int, const std::string&,
720+
.template Get<jit::LSTMKernel<float>, const std::string&,
721721
const std::string&, const std::string&>(
722-
frame_size, act_gate, act_cand, act_cell);
722+
act_gate, act_cand, act_cell, frame_size, false);
723723
const auto& plstm2 =
724724
jit::KernelPool::Instance()
725-
.template Get<jit::LSTMKernel<float>, int, const std::string&,
725+
.template Get<jit::LSTMKernel<float>, const std::string&,
726726
const std::string&, const std::string&>(
727-
frame_size, act_gate, act_cand, act_cell);
728-
EXPECT_EQ(plstm1, plstm2);
727+
act_gate, act_cand, act_cell, frame_size, false);
728+
const auto& peephole =
729+
jit::KernelPool::Instance()
730+
.template Get<jit::LSTMKernel<float>, const std::string&,
731+
const std::string&, const std::string&>(
732+
act_gate, act_cand, act_cell, frame_size, true);
733+
EXPECT_TRUE(plstm1 != peephole);
729734

730735
const auto& pvmul_f =
731736
jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(4);

0 commit comments

Comments
 (0)