@@ -221,10 +221,14 @@ void LSTMJitCode::generate() {
221
221
reg64_t reg_ptr_ct_1 = r9;
222
222
reg64_t reg_ptr_ct = r10;
223
223
reg64_t reg_ptr_ht = r11;
224
+ reg64_t reg_ptr_wp = r12;
224
225
mov (reg_ptr_gates, ptr[param1 + offsetof (lstm_t , gates)]);
225
226
mov (reg_ptr_ct_1, ptr[param1 + offsetof (lstm_t , ct_1)]);
226
227
mov (reg_ptr_ct, ptr[param1 + offsetof (lstm_t , ct)]);
227
228
mov (reg_ptr_ht, ptr[param1 + offsetof (lstm_t , ht)]);
229
+ if (use_peephole_) {
230
+ mov (reg_ptr_wp, ptr[param1 + offsetof (lstm_t , wp)]);
231
+ }
228
232
229
233
int offset = 0 ;
230
234
int d = num_ * sizeof (float );
@@ -235,13 +239,27 @@ void LSTMJitCode::generate() {
235
239
act<ymm_t >(ymm_c, ymm_src, act_cand_);
236
240
// i
237
241
vmovups (ymm_src, ptr[reg_ptr_gates + offset + d]);
242
+ if (!compute_c1h1_ && use_peephole_) {
243
+ ymm_t ymm_wp = ymm_t (2 );
244
+ ymm_t ymm_ct_1 = ymm_t (3 );
245
+ vmovups (ymm_wp, ptr[reg_ptr_wp + offset]);
246
+ vmovups (ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
247
+ vmulps (ymm_wp, ymm_ct_1, ymm_wp);
248
+ vaddps (ymm_src, ymm_src, ymm_wp);
249
+ }
238
250
act<ymm_t >(ymm_i, ymm_src, act_gate_);
239
251
vmulps (ymm_c, ymm_c, ymm_i);
240
252
if (!compute_c1h1_) {
241
253
// f
242
254
vmovups (ymm_src, ptr[reg_ptr_gates + offset + 2 * d]);
243
- act<ymm_t >(ymm_f, ymm_src, act_gate_);
244
255
vmovups (ymm_i, ptr[reg_ptr_ct_1 + offset]);
256
+ if (use_peephole_) {
257
+ ymm_t ymm_wp = ymm_t (3 );
258
+ vmovups (ymm_wp, ptr[reg_ptr_wp + offset + d]);
259
+ vmulps (ymm_wp, ymm_i, ymm_wp);
260
+ vaddps (ymm_src, ymm_src, ymm_wp);
261
+ }
262
+ act<ymm_t >(ymm_f, ymm_src, act_gate_);
245
263
vmulps (ymm_f, ymm_f, ymm_i);
246
264
vaddps (ymm_f, ymm_f, ymm_c);
247
265
}
@@ -250,8 +268,14 @@ void LSTMJitCode::generate() {
250
268
ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c;
251
269
ymm_t ymm_tmp = ymm_i;
252
270
vmovups (ptr[reg_ptr_ct + offset], ymm_ct); // save ct
253
- act<ymm_t >(ymm_tmp, ymm_ct, act_cell_);
254
271
vmovups (ymm_src, ptr[reg_ptr_gates + offset + 3 * d]);
272
+ if (use_peephole_) {
273
+ ymm_t ymm_wp = ymm_t (2 );
274
+ vmovups (ymm_wp, ptr[reg_ptr_wp + offset + d * 2 ]);
275
+ vmulps (ymm_wp, ymm_ct, ymm_wp);
276
+ vaddps (ymm_src, ymm_src, ymm_wp);
277
+ }
278
+ act<ymm_t >(ymm_tmp, ymm_ct, act_cell_);
255
279
act<ymm_t >(ymm_o, ymm_src, act_gate_);
256
280
vmulps (ymm_o, ymm_tmp, ymm_o);
257
281
vmovups (ptr[reg_ptr_ht + offset], ymm_o); // save ht
0 commit comments