Skip to content

Commit 0c5ed5f

Browse files
committed
enable peephole jitcode
test=develop
1 parent e3b61cf commit 0c5ed5f

File tree

2 files changed: 27 additions and 3 deletions.

2 files changed: 27 additions and 3 deletions.

paddle/fluid/operators/math/jit_code.cc

Lines changed: 26 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -221,10 +221,14 @@ void LSTMJitCode::generate() {
221221
reg64_t reg_ptr_ct_1 = r9;
222222
reg64_t reg_ptr_ct = r10;
223223
reg64_t reg_ptr_ht = r11;
224+
reg64_t reg_ptr_wp = r12;
224225
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
225226
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
226227
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
227228
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
229+
if (use_peephole_) {
230+
mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]);
231+
}
228232

229233
int offset = 0;
230234
int d = num_ * sizeof(float);
@@ -235,13 +239,27 @@ void LSTMJitCode::generate() {
235239
act<ymm_t>(ymm_c, ymm_src, act_cand_);
236240
// i
237241
vmovups(ymm_src, ptr[reg_ptr_gates + offset + d]);
242+
if (!compute_c1h1_ && use_peephole_) {
243+
ymm_t ymm_wp = ymm_t(2);
244+
ymm_t ymm_ct_1 = ymm_t(3);
245+
vmovups(ymm_wp, ptr[reg_ptr_wp + offset]);
246+
vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
247+
vmulps(ymm_wp, ymm_ct_1, ymm_wp);
248+
vaddps(ymm_src, ymm_src, ymm_wp);
249+
}
238250
act<ymm_t>(ymm_i, ymm_src, act_gate_);
239251
vmulps(ymm_c, ymm_c, ymm_i);
240252
if (!compute_c1h1_) {
241253
// f
242254
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * d]);
243-
act<ymm_t>(ymm_f, ymm_src, act_gate_);
244255
vmovups(ymm_i, ptr[reg_ptr_ct_1 + offset]);
256+
if (use_peephole_) {
257+
ymm_t ymm_wp = ymm_t(3);
258+
vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d]);
259+
vmulps(ymm_wp, ymm_i, ymm_wp);
260+
vaddps(ymm_src, ymm_src, ymm_wp);
261+
}
262+
act<ymm_t>(ymm_f, ymm_src, act_gate_);
245263
vmulps(ymm_f, ymm_f, ymm_i);
246264
vaddps(ymm_f, ymm_f, ymm_c);
247265
}
@@ -250,8 +268,14 @@ void LSTMJitCode::generate() {
250268
ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c;
251269
ymm_t ymm_tmp = ymm_i;
252270
vmovups(ptr[reg_ptr_ct + offset], ymm_ct); // save ct
253-
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
254271
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * d]);
272+
if (use_peephole_) {
273+
ymm_t ymm_wp = ymm_t(2);
274+
vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d * 2]);
275+
vmulps(ymm_wp, ymm_ct, ymm_wp);
276+
vaddps(ymm_src, ymm_src, ymm_wp);
277+
}
278+
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
255279
act<ymm_t>(ymm_o, ymm_src, act_gate_);
256280
vmulps(ymm_o, ymm_tmp, ymm_o);
257281
vmovups(ptr[reg_ptr_ht + offset], ymm_o); // save ht

paddle/fluid/operators/math/jit_kernel_rnn.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
108108
#ifdef PADDLE_WITH_XBYAK
109109
template <>
110110
bool PeepholeKernelImpl<float>::useJIT(int d) {
111-
return false; // peephole jitcode not ready yet
111+
return gen::LSTMJitCode::init(d);
112112
}
113113
#endif
114114

0 commit comments

Comments
 (0)