Skip to content

Commit 6a7f83d

Browse files
committed
enable gru jitcode and refine act and lstm jitcode
test=develop
1 parent 686eaf2 commit 6a7f83d

File tree

5 files changed

+149
-136
lines changed

5 files changed

+149
-136
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 103 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -140,32 +140,10 @@ bool VActJitCode::init(int d, operand_type type) {
140140
}
141141

142142
void VActJitCode::generate() {
143-
xmm_t xmm_zero = xmm_t(2);
144-
ymm_t ymm_zero = ymm_t(2);
145-
if (type_ == operand_type::relu) {
146-
vxorps(ymm_zero, ymm_zero, ymm_zero);
147-
}
148143
int offset = 0;
149144
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
150145
vmovups(ymm_src, ptr[param1 + offset]);
151-
switch (type_) {
152-
case operand_type::relu:
153-
relu_jmm<ymm_t>(ymm_dst, ymm_src, ymm_zero);
154-
break;
155-
case operand_type::exp:
156-
exp_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
157-
break;
158-
case operand_type::sigmoid:
159-
sigmoid_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
160-
break;
161-
case operand_type::tanh:
162-
tanh_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
163-
break;
164-
case operand_type::identity:
165-
break;
166-
default:
167-
break;
168-
}
146+
act<ymm_t>(ymm_dst, ymm_src, type_);
169147
vmovups(ptr[param2 + offset], ymm_dst);
170148
offset += sizeof(float) * YMM_FLOAT_BLOCK;
171149
}
@@ -182,22 +160,7 @@ void VActJitCode::generate() {
182160
block = 1;
183161
vmovss(xmm_src, ptr[param1 + offset]);
184162
}
185-
switch (type_) {
186-
case operand_type::relu:
187-
relu_jmm<xmm_t>(xmm_dst, xmm_src, xmm_zero);
188-
break;
189-
case operand_type::exp:
190-
exp_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
191-
break;
192-
case operand_type::sigmoid:
193-
sigmoid_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
194-
break;
195-
case operand_type::tanh:
196-
tanh_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
197-
break;
198-
default:
199-
break;
200-
}
163+
act<xmm_t>(xmm_dst, xmm_src, type_);
201164
if (rest >= 4) {
202165
vmovups(ptr[param2 + offset], xmm_dst);
203166
} else if (rest >= 2) {
@@ -233,52 +196,64 @@ void LSTMJitCode::generate() {
233196
int offset = 0;
234197
int d = num_ * sizeof(float);
235198
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
236-
/* C_t = C_t-1 * fgated + cand_gated * igated*/
237-
// c
238-
vmovups(ymm_src, ptr[reg_ptr_gates + offset]);
239-
act<ymm_t>(ymm_c, ymm_src, act_cand_);
240-
// i
241-
vmovups(ymm_src, ptr[reg_ptr_gates + offset + d]);
242-
if (!compute_c1h1_ && use_peephole_) {
243-
ymm_t ymm_wp = ymm_t(2);
244-
ymm_t ymm_ct_1 = ymm_t(3);
245-
vmovups(ymm_wp, ptr[reg_ptr_wp + offset]);
199+
/* gates: W_ch, W_ih, W_fh, W_oh */
200+
ymm_t ymm_c = ymm_t(0);
201+
ymm_t ymm_i = ymm_t(1);
202+
ymm_t ymm_f = ymm_t(2);
203+
ymm_t ymm_o = ymm_t(3);
204+
ymm_t ymm_ct_1 = ymm_t(4);
205+
ymm_t ymm_wp0 = ymm_t(5);
206+
ymm_t ymm_wp1 = ymm_t(6);
207+
ymm_t ymm_wp2 = ymm_t(7);
208+
vmovups(ymm_c, ptr[reg_ptr_gates + offset]);
209+
vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]);
210+
vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]);
211+
vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]);
212+
if (!compute_c1h1_) {
246213
vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
247-
vmulps(ymm_wp, ymm_ct_1, ymm_wp);
248-
vaddps(ymm_src, ymm_src, ymm_wp);
249214
}
250-
act<ymm_t>(ymm_i, ymm_src, act_gate_);
215+
if (use_peephole_) {
216+
vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]);
217+
vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]);
218+
vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]);
219+
}
220+
/* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
221+
// act_cand(c)
222+
act<ymm_t>(ymm_c, ymm_c, act_cand_);
223+
// act_gate(i) or act_gate(ct_1 * wp0 + i)
224+
if (!compute_c1h1_ && use_peephole_) {
225+
vmulps(ymm_wp0, ymm_ct_1, ymm_wp0);
226+
vaddps(ymm_i, ymm_i, ymm_wp0);
227+
}
228+
act<ymm_t>(ymm_i, ymm_i, act_gate_);
251229
vmulps(ymm_c, ymm_c, ymm_i);
252230
if (!compute_c1h1_) {
253-
// f
254-
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * d]);
255-
vmovups(ymm_i, ptr[reg_ptr_ct_1 + offset]);
231+
// act_gate(f) or act_gate(ct_1 * wp1 + f)
256232
if (use_peephole_) {
257-
ymm_t ymm_wp = ymm_t(3);
258-
vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d]);
259-
vmulps(ymm_wp, ymm_i, ymm_wp);
260-
vaddps(ymm_src, ymm_src, ymm_wp);
233+
vmulps(ymm_wp1, ymm_ct_1, ymm_wp1);
234+
vaddps(ymm_f, ymm_f, ymm_wp1);
261235
}
262-
act<ymm_t>(ymm_f, ymm_src, act_gate_);
263-
vmulps(ymm_f, ymm_f, ymm_i);
236+
act<ymm_t>(ymm_f, ymm_f, act_gate_);
237+
// ct
238+
vmulps(ymm_f, ymm_f, ymm_ct_1);
264239
vaddps(ymm_f, ymm_f, ymm_c);
265240
}
266-
/* H_t = act_cell(C_t) * ogated */
241+
/* H_t = act_cell(C_t) * act_gate(o) */
242+
// act_cell(C_t)
267243
ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
268-
ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c;
269244
ymm_t ymm_tmp = ymm_i;
270-
vmovups(ptr[reg_ptr_ct + offset], ymm_ct); // save ct
271-
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * d]);
245+
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
246+
// act_gate(o) or act_gate(ct * wp2 + o)
272247
if (use_peephole_) {
273-
ymm_t ymm_wp = ymm_t(2);
274-
vmovups(ymm_wp, ptr[reg_ptr_wp + offset + d * 2]);
275-
vmulps(ymm_wp, ymm_ct, ymm_wp);
276-
vaddps(ymm_src, ymm_src, ymm_wp);
248+
vmulps(ymm_wp2, ymm_ct, ymm_wp2);
249+
vaddps(ymm_o, ymm_o, ymm_wp2);
277250
}
278-
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
279-
act<ymm_t>(ymm_o, ymm_src, act_gate_);
280-
vmulps(ymm_o, ymm_tmp, ymm_o);
281-
vmovups(ptr[reg_ptr_ht + offset], ymm_o); // save ht
251+
act<ymm_t>(ymm_o, ymm_o, act_gate_);
252+
// ht
253+
vmulps(ymm_o, ymm_o, ymm_tmp);
254+
// save ct and ht
255+
vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
256+
vmovups(ptr[reg_ptr_ht + offset], ymm_o);
282257
offset += sizeof(float) * YMM_FLOAT_BLOCK;
283258
}
284259

@@ -293,13 +268,61 @@ bool GRUJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; }
293268

294269
void GRUJitCode::generate() {
295270
reg64_t reg_ptr_gates = rax;
296-
reg64_t reg_ptr_ct_1 = r9;
297-
reg64_t reg_ptr_ct = r10;
298-
reg64_t reg_ptr_ht = r11;
299-
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
300-
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
301-
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
302-
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
271+
reg64_t reg_ptr_ht_1 = r9;
272+
reg64_t reg_ptr_ht = r10;
273+
mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]);
274+
mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]);
275+
mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]);
276+
ymm_t ymm_one = ymm_t(0);
277+
278+
if (id_ == 2) {
279+
reg64_t reg_ptr_tmp = r11;
280+
mov(reg_ptr_tmp, reinterpret_cast<size_t>(exp_float_consts));
281+
vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]);
282+
}
283+
int offset = 0;
284+
int d = num_ * sizeof(float);
285+
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
286+
ymm_t ymm_u = ymm_t(1);
287+
ymm_t ymm_r = ymm_t(2);
288+
ymm_t ymm_s = ymm_t(3);
289+
ymm_t ymm_ht_1 = ymm_t(4);
290+
// W: {W_update, W_reset; W_state}
291+
if (id_ == 0 || id_ == 2) {
292+
vmovups(ymm_u, ptr[reg_ptr_gates + offset]);
293+
vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]);
294+
}
295+
if (id_ == 1) {
296+
vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]);
297+
}
298+
if (id_ == 1 || id_ == 2) {
299+
vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]);
300+
}
301+
302+
if (id_ == 0) {
303+
// ht = act_gate(u) * act_cand(s)
304+
act<ymm_t>(ymm_u, ymm_u, act_gate_);
305+
act<ymm_t>(ymm_s, ymm_s, act_cand_);
306+
vmulps(ymm_s, ymm_s, ymm_u);
307+
vmovups(ptr[reg_ptr_ht + offset], ymm_s);
308+
} else if (id_ == 1) {
309+
// ht = act_gate(r) * ht_1
310+
act<ymm_t>(ymm_r, ymm_r, act_gate_);
311+
vmulps(ymm_r, ymm_r, ymm_ht_1);
312+
vmovups(ptr[reg_ptr_ht + offset], ymm_r);
313+
} else if (id_ == 2) {
314+
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
315+
ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx());
316+
act<ymm_t>(ymm_u, ymm_u, act_gate_);
317+
act<ymm_t>(ymm_s, ymm_s, act_cand_);
318+
vmulps(ymm_s, ymm_s, ymm_u);
319+
vsubps(ymm_u, ymm_one_inner, ymm_u);
320+
vmulps(ymm_u, ymm_ht_1, ymm_u);
321+
vaddps(ymm_u, ymm_s, ymm_u);
322+
vmovups(ptr[reg_ptr_ht + offset], ymm_u);
323+
}
324+
offset += sizeof(float) * YMM_FLOAT_BLOCK;
325+
}
303326

304327
ret();
305328
}

0 commit comments

Comments
 (0)