Skip to content

Commit f913860

Browse files
committed
jitkernel lstm refer support peephole
test=develop
1 parent 2f9b5f2 commit f913860

File tree

9 files changed

+250
-263
lines changed

9 files changed

+250
-263
lines changed

paddle/fluid/operators/fused/fusion_lstm_op.cc

Lines changed: 46 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -236,27 +236,31 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
236236
const int D = wh_dims[0]; \
237237
const int D4 = wh_dims[1]
238238

239-
#define INIT_OTHER_DEFINES \
240-
const T* x_data = x->data<T>(); \
241-
const T* wx_data = wx->data<T>(); \
242-
const T* wh_data = wh->data<T>(); \
243-
/* diagonal weight*/ \
244-
const T* wp_data = bias->data<T>() + D4; \
245-
/* for peephole only*/ \
246-
T* checked_cell_data = nullptr; \
247-
auto place = ctx.GetPlace(); \
248-
if (use_peepholes) { \
249-
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
250-
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
251-
checked_cell_data = checked_cell->mutable_data<T>(place); \
252-
} \
253-
const auto& ker = \
254-
math::jitkernel::KernelPool::Instance() \
255-
.template Get<math::jitkernel::LSTMKernel<T>, const std::string&, \
256-
const std::string&, const std::string&>( \
257-
ctx.Attr<std::string>("gate_activation"), \
258-
ctx.Attr<std::string>("candidate_activation"), \
259-
ctx.Attr<std::string>("cell_activation"), D, use_peepholes)
239+
#define INIT_OTHER_DEFINES \
240+
const T* x_data = x->data<T>(); \
241+
const T* wx_data = wx->data<T>(); \
242+
const T* wh_data = wh->data<T>(); \
243+
/* diagonal weight*/ \
244+
const T* wp_data = bias->data<T>() + D4; \
245+
/* for peephole only*/ \
246+
T* checked_cell_data = nullptr; \
247+
auto place = ctx.GetPlace(); \
248+
if (use_peepholes) { \
249+
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
250+
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
251+
checked_cell_data = checked_cell->mutable_data<T>(place); \
252+
} \
253+
const math::jitkernel::lstm_attr_t attr( \
254+
D, ctx.Attr<std::string>("gate_activation"), \
255+
ctx.Attr<std::string>("candidate_activation"), \
256+
ctx.Attr<std::string>("cell_activation"), use_peepholes); \
257+
math::jitkernel::lstm_t one_step; \
258+
one_step.wp = wp_data; \
259+
one_step.checked = checked_cell_data; \
260+
const auto& ker = \
261+
math::jitkernel::KernelPool::Instance() \
262+
.template Get<math::jitkernel::LSTMKernel<T>, \
263+
const math::jitkernel::lstm_attr_t&>(attr)
260264

261265
// Wh GEMM
262266
#define GEMM_WH_ADDON(bs, prev, out) \
@@ -299,7 +303,10 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
299303
prev_h_data = h0_data + bid * D;
300304
prev_c_data = c0_data + bid * D;
301305
} else {
302-
ker->ComputeC1H1(xx_data, c_out_data, h_out_data, wp_data);
306+
one_step.gates = xx_data;
307+
one_step.ct = c_out_data;
308+
one_step.ht = h_out_data;
309+
ker->ComputeC1H1(&one_step, &attr);
303310
tstart = 1;
304311
// move one step
305312
prev_h_data = h_out_data;
@@ -310,8 +317,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
310317
}
311318
for (int step = tstart; step < seq_len; ++step) {
312319
GEMM_WH_ADDON(1, prev_h_data, xx_data);
313-
ker->ComputeCtHt(xx_data, prev_c_data, c_out_data, h_out_data, wp_data,
314-
checked_cell_data);
320+
321+
one_step.gates = xx_data;
322+
one_step.ct_1 = prev_c_data;
323+
one_step.ct = c_out_data;
324+
one_step.ht = h_out_data;
325+
ker->ComputeCtHt(&one_step, &attr);
315326
// move one step
316327
prev_h_data = h_out_data;
317328
prev_c_data = c_out_data;
@@ -388,7 +399,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
388399
T* cur_h_out_data = batched_h_out_data;
389400
T* cur_c_out_data = batched_c_out_data;
390401
for (int i = 0; i < max_bs; ++i) {
391-
ker->ComputeC1H1(cur_in_data, cur_c_out_data, cur_h_out_data, wp_data);
402+
one_step.gates = cur_in_data;
403+
one_step.ct = cur_c_out_data;
404+
one_step.ht = cur_h_out_data;
405+
ker->ComputeC1H1(&one_step, &attr);
406+
392407
cur_in_data += D4;
393408
cur_c_out_data += D;
394409
cur_h_out_data += D;
@@ -413,8 +428,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
413428
T* cur_c_out_data = batched_c_out_data;
414429
T* cur_h_out_data = batched_h_out_data;
415430
for (int i = 0; i < cur_bs; ++i) {
416-
ker->ComputeCtHt(cur_in_data, cur_prev_c_data, cur_c_out_data,
417-
cur_h_out_data, wp_data, checked_cell_data);
431+
one_step.gates = cur_in_data;
432+
one_step.ct_1 = cur_prev_c_data;
433+
one_step.ct = cur_c_out_data;
434+
one_step.ht = cur_h_out_data;
435+
ker->ComputeCtHt(&one_step, &attr);
436+
418437
// move one batch
419438
cur_in_data += D4;
420439
cur_prev_c_data += D;

paddle/fluid/operators/math/jit_code.cc

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -233,7 +233,7 @@ void LSTMJitCode::generate() {
233233
vmovups(ymm_src, ptr[reg_ptr_gates + offset + num_]);
234234
act<ymm_t>(ymm_i, ymm_src, act_gate_);
235235
vmulps(ymm_c, ymm_c, ymm_i);
236-
if (first_) {
236+
if (!compute_c1h1_) {
237237
// f
238238
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * num_]);
239239
act<ymm_t>(ymm_f, ymm_src, act_gate_);
@@ -242,8 +242,8 @@ void LSTMJitCode::generate() {
242242
vaddps(ymm_f, ymm_f, ymm_c);
243243
}
244244
/* H_t = act_cell(C_t) * ogated */
245-
ymm_t ymm_ct = first_ ? ymm_c : ymm_f;
246-
ymm_t ymm_o = first_ ? ymm_f : ymm_c;
245+
ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
246+
ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c;
247247
ymm_t ymm_tmp = ymm_i;
248248
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
249249
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * num_]);

paddle/fluid/operators/math/jit_code.h

Lines changed: 30 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -319,6 +319,12 @@ class LSTMJitCode : public VActJitCode {
319319
public:
320320
const char* name() const override {
321321
std::string base = "LSTMJitCode";
322+
if (use_peephole_) {
323+
base += "_Peephole";
324+
}
325+
if (compute_c1h1_) {
326+
base += "_C1H1";
327+
}
322328
auto AddTypeStr = [&](operand_type type) {
323329
switch (type) {
324330
case operand_type::relu:
@@ -340,30 +346,42 @@ class LSTMJitCode : public VActJitCode {
340346
break;
341347
}
342348
};
343-
if (first_) {
344-
base += "_C1H1";
345-
}
346349
AddTypeStr(act_gate_);
347350
AddTypeStr(act_cand_);
348351
AddTypeStr(act_cell_);
349352
return base.c_str();
350353
}
351354

352-
explicit LSTMJitCode(int d, bool first, operand_type act_gate,
353-
operand_type act_cand, operand_type act_cell,
355+
explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
354356
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
355-
: VActJitCode(d, act_gate, code_size, code_ptr),
356-
num_(d),
357-
first_(first),
358-
act_gate_(act_gate),
359-
act_cand_(act_cand),
360-
act_cell_(act_cell) {}
357+
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
358+
code_ptr),
359+
compute_c1h1_(compute_c1h1) {
360+
auto typeExchange = [](const std::string& type) -> gen::operand_type {
361+
if (type == "sigmoid") {
362+
return operand_type::sigmoid;
363+
} else if (type == "relu") {
364+
return operand_type::relu;
365+
} else if (type == "tanh") {
366+
return operand_type::tanh;
367+
} else if (type == "identity" || type == "") {
368+
return operand_type::identity;
369+
} // else throw error
370+
return operand_type::identity;
371+
};
372+
num_ = attr.d;
373+
use_peephole_ = attr.use_peephole;
374+
act_gate_ = typeExchange(attr.act_gate);
375+
act_cand_ = typeExchange(attr.act_cand);
376+
act_cell_ = typeExchange(attr.act_cell);
377+
}
361378
static bool init(int d);
362379
void generate() override;
363380

364381
protected:
365382
int num_;
366-
bool first_;
383+
bool compute_c1h1_;
384+
bool use_peephole_;
367385
operand_type act_gate_;
368386
operand_type act_cand_;
369387
operand_type act_cell_;

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 3 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -122,18 +122,9 @@ class VTanhKernel : public VActKernel<T> {};
122122
template <typename T>
123123
class LSTMKernel : public Kernel {
124124
public:
125-
virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht,
126-
/* below only used in peephole*/
127-
const T *wp_data = nullptr,
128-
T *checked = nullptr) const = 0;
129-
130-
virtual void ComputeC1H1(T *gates, T *ct, T *ht,
131-
/* below only used in peephole*/
132-
const T *wp_data = nullptr) const = 0;
133-
134-
// void (*ComputeCtHt)(lstm_t *);
135-
// // compute c1 and h1 without c0 or h0
136-
// void (*ComputeC1H1)(lstm_t *);
125+
void (*ComputeCtHt)(lstm_t *, const lstm_attr_t *);
126+
// compute c1 and h1 without c0 or h0
127+
void (*ComputeC1H1)(lstm_t *, const lstm_attr_t *);
137128
};
138129

139130
template <typename T>

paddle/fluid/operators/math/jit_kernel_impl.h

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -33,18 +33,24 @@ typedef struct {
3333
const void* ct_1;
3434
void* ct;
3535
void* ht;
36-
/* below only used in peephole*/
37-
const void* wp_data{nullptr};
36+
/* weight_peephole and checked data are only used in peephole*/
37+
const void* wp{nullptr};
3838
void* checked{nullptr};
3939
} lstm_t;
4040

4141
typedef struct lstm_attr_s {
42+
bool use_peephole;
4243
int d;
4344
std::string act_gate, act_cand, act_cell;
4445
lstm_attr_s() = default;
4546
lstm_attr_s(int _d, const std::string& _act_gate,
46-
const std::string& _act_cand, const std::string& _act_cell)
47-
: d(_d), act_gate(_act_gate), act_cand(_act_cand), act_cell(_act_cell) {}
47+
const std::string& _act_cand, const std::string& _act_cell,
48+
bool _use_peephole = false)
49+
: use_peephole(_use_peephole),
50+
d(_d),
51+
act_gate(_act_gate),
52+
act_cand(_act_cand),
53+
act_cell(_act_cell) {}
4854
} lstm_attr_t;
4955

5056
} // namespace jitkernel

paddle/fluid/operators/math/jit_kernel_macro.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -82,10 +82,10 @@ namespace jitkernel {
8282
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \
8383
marco_declare, macro_find_key, macro_impl) \
8484
marco_define_name(ker_key, ker_class); \
85-
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, JITKERNEL_DECLARE, \
86-
JITKERNEL_FIND_KEY, JITKERNEL_IMPL); \
87-
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, JITKERNEL_DECLARE, \
88-
JITKERNEL_FIND_KEY, JITKERNEL_IMPL)
85+
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, marco_declare, \
86+
macro_find_key, macro_impl); \
87+
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, marco_declare, \
88+
macro_find_key, macro_impl)
8989

9090
#define REGISTER_JITKERNEL(ker_key, ker_class) \
9191
REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \

paddle/fluid/operators/math/jit_kernel_refer.h

Lines changed: 28 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -117,35 +117,50 @@ void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT
117117
}
118118

119119
template <typename T>
120-
void LSTMCtHt(lstm_t* step, lstm_attr_t* attr) {
120+
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
121121
T* gates = reinterpret_cast<T*>(step->gates);
122122
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
123123
T* ct = reinterpret_cast<T*>(step->ct);
124124
T* ht = reinterpret_cast<T*>(step->ht);
125+
const T* wp = reinterpret_cast<const T*>(step->wp);
126+
T* checked = reinterpret_cast<T*>(step->checked);
125127
auto act_gate = getActFunc<T>(attr->act_gate);
126128
auto act_cand = getActFunc<T>(attr->act_cand);
127129
auto act_cell = getActFunc<T>(attr->act_cell);
128130
int d = attr->d;
129131
int d2 = d * 2;
130132
int d3 = d * 3;
131133
// gates: W_ch, W_ih, W_fh, W_oh
132-
act_gate(gates + d, gates + d, d3);
134+
if (attr->use_peephole) {
135+
VMul(wp, ct_1, checked, d);
136+
VMul(wp + d, ct_1, checked + d, d);
137+
VAdd(checked, gates + d, gates + d, d2);
138+
act_gate(gates + d, gates + d, d2);
139+
} else {
140+
act_gate(gates + d, gates + d, d3);
141+
}
133142

134-
/* C_t = C_t-1 * fgated + cand_gated * igated */
143+
// C_t = C_t-1 * fgated + cand_gated * igated
135144
act_cand(gates, gates, d);
136145
VMul(gates, gates + d, gates + d, d);
137146
VMul(ct_1, gates + d2, gates + d2, d);
138147
VAdd(gates + d, gates + d2, ct, d);
139148

140-
/* H_t = act_cell(C_t) * ogated */
149+
if (attr->use_peephole) {
150+
// get ogated
151+
VMul(wp + d2, ct, gates + d, d);
152+
VAdd(gates + d, gates + d3, gates + d3, d);
153+
act_gate(gates + d3, gates + d3, d);
154+
}
155+
// H_t = act_cell(C_t) * ogated
141156
act_cell(ct, gates + d2, d);
142157
VMul(gates + d2, gates + d3, ht, d);
143158
}
144159

160+
// compute c1 and h1 without c0 or h0
145161
template <typename T>
146-
void LSTMC1H1(lstm_t* step, lstm_attr_t* attr) {
162+
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
147163
T* gates = reinterpret_cast<T*>(step->gates);
148-
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
149164
T* ct = reinterpret_cast<T*>(step->ct);
150165
T* ht = reinterpret_cast<T*>(step->ht);
151166
auto act_gate = getActFunc<T>(attr->act_gate);
@@ -158,10 +173,16 @@ void LSTMC1H1(lstm_t* step, lstm_attr_t* attr) {
158173
act_gate(gates + d, gates + d, d);
159174
act_cand(gates, gates, d);
160175
VMul(gates, gates + d, ct, d);
176+
if (attr->use_peephole) {
177+
// get outgated, put W_oc * C_t on igated
178+
const T* wp = reinterpret_cast<const T*>(step->wp);
179+
VMul(wp + d2, ct, gates + d, d);
180+
VAdd(gates + d, gates + d3, gates + d3, d);
181+
}
161182
/* H_t = act_cell(C_t) * ogated */
162183
act_gate(gates + d3, gates + d3, d);
163184
act_cell(ct, gates + d2, d);
164-
Vmul(gates + d2, gates + d3, ht, d);
185+
VMul(gates + d2, gates + d3, ht, d);
165186
}
166187

167188
} // namespace refer

0 commit comments

Comments
 (0)