@@ -140,32 +140,10 @@ bool VActJitCode::init(int d, operand_type type) {
140
140
}
141
141
142
142
void VActJitCode::generate () {
143
- xmm_t xmm_zero = xmm_t (2 );
144
- ymm_t ymm_zero = ymm_t (2 );
145
- if (type_ == operand_type::relu) {
146
- vxorps (ymm_zero, ymm_zero, ymm_zero);
147
- }
148
143
int offset = 0 ;
149
144
for (int i = 0 ; i < num_ / YMM_FLOAT_BLOCK; ++i) {
150
145
vmovups (ymm_src, ptr[param1 + offset]);
151
- switch (type_) {
152
- case operand_type::relu:
153
- relu_jmm<ymm_t >(ymm_dst, ymm_src, ymm_zero);
154
- break ;
155
- case operand_type::exp:
156
- exp_jmm<ymm_t >(ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
157
- break ;
158
- case operand_type::sigmoid:
159
- sigmoid_jmm<ymm_t >(ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
160
- break ;
161
- case operand_type::tanh:
162
- tanh_jmm<ymm_t >(ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
163
- break ;
164
- case operand_type::identity:
165
- break ;
166
- default :
167
- break ;
168
- }
146
+ act<ymm_t >(ymm_dst, ymm_src, type_);
169
147
vmovups (ptr[param2 + offset], ymm_dst);
170
148
offset += sizeof (float ) * YMM_FLOAT_BLOCK;
171
149
}
@@ -182,22 +160,7 @@ void VActJitCode::generate() {
182
160
block = 1 ;
183
161
vmovss (xmm_src, ptr[param1 + offset]);
184
162
}
185
- switch (type_) {
186
- case operand_type::relu:
187
- relu_jmm<xmm_t >(xmm_dst, xmm_src, xmm_zero);
188
- break ;
189
- case operand_type::exp:
190
- exp_jmm<xmm_t >(xmm_dst, xmm_src, 2 , 3 , 4 , 5 );
191
- break ;
192
- case operand_type::sigmoid:
193
- sigmoid_jmm<xmm_t >(xmm_dst, xmm_src, 2 , 3 , 4 , 5 );
194
- break ;
195
- case operand_type::tanh:
196
- tanh_jmm<xmm_t >(xmm_dst, xmm_src, 2 , 3 , 4 , 5 );
197
- break ;
198
- default :
199
- break ;
200
- }
163
+ act<xmm_t >(xmm_dst, xmm_src, type_);
201
164
if (rest >= 4 ) {
202
165
vmovups (ptr[param2 + offset], xmm_dst);
203
166
} else if (rest >= 2 ) {
@@ -233,52 +196,64 @@ void LSTMJitCode::generate() {
233
196
int offset = 0 ;
234
197
int d = num_ * sizeof (float );
235
198
for (int i = 0 ; i < num_ / YMM_FLOAT_BLOCK; ++i) {
236
- /* C_t = C_t-1 * fgated + cand_gated * igated*/
237
- // c
238
- vmovups (ymm_src, ptr[reg_ptr_gates + offset]);
239
- act<ymm_t >(ymm_c, ymm_src, act_cand_);
240
- // i
241
- vmovups (ymm_src, ptr[reg_ptr_gates + offset + d]);
242
- if (!compute_c1h1_ && use_peephole_) {
243
- ymm_t ymm_wp = ymm_t (2 );
244
- ymm_t ymm_ct_1 = ymm_t (3 );
245
- vmovups (ymm_wp, ptr[reg_ptr_wp + offset]);
199
+ /* gates: W_ch, W_ih, W_fh, W_oh */
200
+ ymm_t ymm_c = ymm_t (0 );
201
+ ymm_t ymm_i = ymm_t (1 );
202
+ ymm_t ymm_f = ymm_t (2 );
203
+ ymm_t ymm_o = ymm_t (3 );
204
+ ymm_t ymm_ct_1 = ymm_t (4 );
205
+ ymm_t ymm_wp0 = ymm_t (5 );
206
+ ymm_t ymm_wp1 = ymm_t (6 );
207
+ ymm_t ymm_wp2 = ymm_t (7 );
208
+ vmovups (ymm_c, ptr[reg_ptr_gates + offset]);
209
+ vmovups (ymm_i, ptr[reg_ptr_gates + offset + d]);
210
+ vmovups (ymm_f, ptr[reg_ptr_gates + offset + 2 * d]);
211
+ vmovups (ymm_o, ptr[reg_ptr_gates + offset + 3 * d]);
212
+ if (!compute_c1h1_) {
246
213
vmovups (ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
247
- vmulps (ymm_wp, ymm_ct_1, ymm_wp);
248
- vaddps (ymm_src, ymm_src, ymm_wp);
249
214
}
250
- act<ymm_t >(ymm_i, ymm_src, act_gate_);
215
+ if (use_peephole_) {
216
+ vmovups (ymm_wp0, ptr[reg_ptr_wp + offset]);
217
+ vmovups (ymm_wp1, ptr[reg_ptr_wp + offset + d]);
218
+ vmovups (ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]);
219
+ }
220
+ /* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
221
+ // act_cand(c)
222
+ act<ymm_t >(ymm_c, ymm_c, act_cand_);
223
+ // act_gate(i) or act_gate(ct_1 * wp0 + i)
224
+ if (!compute_c1h1_ && use_peephole_) {
225
+ vmulps (ymm_wp0, ymm_ct_1, ymm_wp0);
226
+ vaddps (ymm_i, ymm_i, ymm_wp0);
227
+ }
228
+ act<ymm_t >(ymm_i, ymm_i, act_gate_);
251
229
vmulps (ymm_c, ymm_c, ymm_i);
252
230
if (!compute_c1h1_) {
253
- // f
254
- vmovups (ymm_src, ptr[reg_ptr_gates + offset + 2 * d]);
255
- vmovups (ymm_i, ptr[reg_ptr_ct_1 + offset]);
231
+ // act_gate(f) or act_gate(ct_1 * wp1 + f)
256
232
if (use_peephole_) {
257
- ymm_t ymm_wp = ymm_t (3 );
258
- vmovups (ymm_wp, ptr[reg_ptr_wp + offset + d]);
259
- vmulps (ymm_wp, ymm_i, ymm_wp);
260
- vaddps (ymm_src, ymm_src, ymm_wp);
233
+ vmulps (ymm_wp1, ymm_ct_1, ymm_wp1);
234
+ vaddps (ymm_f, ymm_f, ymm_wp1);
261
235
}
262
- act<ymm_t >(ymm_f, ymm_src, act_gate_);
263
- vmulps (ymm_f, ymm_f, ymm_i);
236
+ act<ymm_t >(ymm_f, ymm_f, act_gate_);
237
+ // ct
238
+ vmulps (ymm_f, ymm_f, ymm_ct_1);
264
239
vaddps (ymm_f, ymm_f, ymm_c);
265
240
}
266
- /* H_t = act_cell(C_t) * ogated */
241
+ /* H_t = act_cell(C_t) * act_gate(o) */
242
+ // act_cell(C_t)
267
243
ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
268
- ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c;
269
244
ymm_t ymm_tmp = ymm_i;
270
- vmovups (ptr[reg_ptr_ct + offset] , ymm_ct); // save ct
271
- vmovups (ymm_src, ptr[reg_ptr_gates + offset + 3 * d]);
245
+ act< ymm_t >(ymm_tmp , ymm_ct, act_cell_);
246
+ // act_gate(o) or act_gate(ct * wp2 + o)
272
247
if (use_peephole_) {
273
- ymm_t ymm_wp = ymm_t (2 );
274
- vmovups (ymm_wp, ptr[reg_ptr_wp + offset + d * 2 ]);
275
- vmulps (ymm_wp, ymm_ct, ymm_wp);
276
- vaddps (ymm_src, ymm_src, ymm_wp);
248
+ vmulps (ymm_wp2, ymm_ct, ymm_wp2);
249
+ vaddps (ymm_o, ymm_o, ymm_wp2);
277
250
}
278
- act<ymm_t >(ymm_tmp, ymm_ct, act_cell_);
279
- act<ymm_t >(ymm_o, ymm_src, act_gate_);
280
- vmulps (ymm_o, ymm_tmp, ymm_o);
281
- vmovups (ptr[reg_ptr_ht + offset], ymm_o); // save ht
251
+ act<ymm_t >(ymm_o, ymm_o, act_gate_);
252
+ // ht
253
+ vmulps (ymm_o, ymm_o, ymm_tmp);
254
+ // save ct and ht
255
+ vmovups (ptr[reg_ptr_ct + offset], ymm_ct);
256
+ vmovups (ptr[reg_ptr_ht + offset], ymm_o);
282
257
offset += sizeof (float ) * YMM_FLOAT_BLOCK;
283
258
}
284
259
@@ -293,13 +268,61 @@ bool GRUJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; }
293
268
294
269
void GRUJitCode::generate () {
295
270
reg64_t reg_ptr_gates = rax;
296
- reg64_t reg_ptr_ct_1 = r9;
297
- reg64_t reg_ptr_ct = r10;
298
- reg64_t reg_ptr_ht = r11;
299
- mov (reg_ptr_gates, ptr[param1 + offsetof (lstm_t , gates)]);
300
- mov (reg_ptr_ct_1, ptr[param1 + offsetof (lstm_t , ct_1)]);
301
- mov (reg_ptr_ct, ptr[param1 + offsetof (lstm_t , ct)]);
302
- mov (reg_ptr_ht, ptr[param1 + offsetof (lstm_t , ht)]);
271
+ reg64_t reg_ptr_ht_1 = r9;
272
+ reg64_t reg_ptr_ht = r10;
273
+ mov (reg_ptr_gates, ptr[param1 + offsetof (gru_t , gates)]);
274
+ mov (reg_ptr_ht_1, ptr[param1 + offsetof (gru_t , ht_1)]);
275
+ mov (reg_ptr_ht, ptr[param1 + offsetof (gru_t , ht)]);
276
+ ymm_t ymm_one = ymm_t (0 );
277
+
278
+ if (id_ == 2 ) {
279
+ reg64_t reg_ptr_tmp = r11;
280
+ mov (reg_ptr_tmp, reinterpret_cast <size_t >(exp_float_consts));
281
+ vmovaps (ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]);
282
+ }
283
+ int offset = 0 ;
284
+ int d = num_ * sizeof (float );
285
+ for (int i = 0 ; i < num_ / YMM_FLOAT_BLOCK; ++i) {
286
+ ymm_t ymm_u = ymm_t (1 );
287
+ ymm_t ymm_r = ymm_t (2 );
288
+ ymm_t ymm_s = ymm_t (3 );
289
+ ymm_t ymm_ht_1 = ymm_t (4 );
290
+ // W: {W_update, W_reset; W_state}
291
+ if (id_ == 0 || id_ == 2 ) {
292
+ vmovups (ymm_u, ptr[reg_ptr_gates + offset]);
293
+ vmovups (ymm_s, ptr[reg_ptr_gates + offset + 2 * d]);
294
+ }
295
+ if (id_ == 1 ) {
296
+ vmovups (ymm_r, ptr[reg_ptr_gates + offset + d]);
297
+ }
298
+ if (id_ == 1 || id_ == 2 ) {
299
+ vmovups (ymm_ht_1, ptr[reg_ptr_ht_1 + offset]);
300
+ }
301
+
302
+ if (id_ == 0 ) {
303
+ // ht = act_gate(u) * act_cand(s)
304
+ act<ymm_t >(ymm_u, ymm_u, act_gate_);
305
+ act<ymm_t >(ymm_s, ymm_s, act_cand_);
306
+ vmulps (ymm_s, ymm_s, ymm_u);
307
+ vmovups (ptr[reg_ptr_ht + offset], ymm_s);
308
+ } else if (id_ == 1 ) {
309
+ // ht = act_gate(r) * ht_1
310
+ act<ymm_t >(ymm_r, ymm_r, act_gate_);
311
+ vmulps (ymm_r, ymm_r, ymm_ht_1);
312
+ vmovups (ptr[reg_ptr_ht + offset], ymm_r);
313
+ } else if (id_ == 2 ) {
314
+ // ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
315
+ ymm_t ymm_one_inner = ymm_t (ymm_one.getIdx ());
316
+ act<ymm_t >(ymm_u, ymm_u, act_gate_);
317
+ act<ymm_t >(ymm_s, ymm_s, act_cand_);
318
+ vmulps (ymm_s, ymm_s, ymm_u);
319
+ vsubps (ymm_u, ymm_one_inner, ymm_u);
320
+ vmulps (ymm_u, ymm_ht_1, ymm_u);
321
+ vaddps (ymm_u, ymm_s, ymm_u);
322
+ vmovups (ptr[reg_ptr_ht + offset], ymm_u);
323
+ }
324
+ offset += sizeof (float ) * YMM_FLOAT_BLOCK;
325
+ }
303
326
304
327
ret ();
305
328
}
0 commit comments