@@ -81,10 +81,10 @@ void VXXJitCode::generate() {
81
81
}
82
82
if (rest >= 2 ) {
83
83
if (scalar_index_ != 1 ) {
84
- vmovups (xmm_src1, ptr[param1 + offset]);
84
+ vmovq (xmm_src1, ptr[param1 + offset]);
85
85
}
86
86
if (scalar_index_ != 2 ) {
87
- vmovups (xmm_src2, ptr[param2 + offset]);
87
+ vmovq (xmm_src2, ptr[param2 + offset]);
88
88
}
89
89
if (type_ == operand_type::mul) {
90
90
vmulps (xmm_dst, xmm_src1, xmm_src2);
@@ -100,10 +100,10 @@ void VXXJitCode::generate() {
100
100
}
101
101
if (rest > 0 ) {
102
102
if (scalar_index_ != 1 ) {
103
- vmovups (xmm_src1, ptr[param1 + offset]);
103
+ vmovss (xmm_src1, ptr[param1 + offset]);
104
104
}
105
105
if (scalar_index_ != 2 ) {
106
- vmovups (xmm_src2, ptr[param2 + offset]);
106
+ vmovss (xmm_src2, ptr[param2 + offset]);
107
107
}
108
108
if (type_ == operand_type::mul) {
109
109
vmulss (xmm_dst, xmm_src1, xmm_src2);
@@ -179,7 +179,7 @@ bool VActJitCode::init(int d, operand_type type) {
179
179
return ok;
180
180
} else if (type == operand_type::exp) {
181
181
// exp is slower than mkl when d >= 256
182
- return ok && d % 8 == 0 && d < 256 ;
182
+ return ok; // && d % 4 == 0 && d < 256;
183
183
} else {
184
184
// TODO(TJ): support more
185
185
return ok && d % 8 == 0 ;
@@ -190,6 +190,10 @@ void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
190
190
vmaxps (ymm_dst, ymm_zero, ymm_src);
191
191
}
192
192
193
// Issues a single vmaxps over xmm-typed operands, combining xmm_zero and
// xmm_src into xmm_dst (destination-first operand order assumed -- confirm
// against the assembler API in use).
// NOTE(review): presumably xmm_zero holds all zeros so the result is
// max(0, x), i.e. ReLU; the register's contents are set up by the caller and
// are not visible from this function.
+ void VActJitCode::relu_xmm (xmm_t & xmm_dst, xmm_t & xmm_src, xmm_t & xmm_zero) {
194
+ vmaxps (xmm_dst, xmm_zero, xmm_src);
195
+ }
196
+
193
197
void VActJitCode::exp_ymm (ymm_t & ymm_dst, ymm_t & ymm_src, int fx_idx,
194
198
int fy_idx, int mask_idx, int tmp_idx) {
195
199
assert (ymm_src.getIdx () != ymm_dst.getIdx ()); // TODO(TJ): use enfore
@@ -271,6 +275,65 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
271
275
pop (reg_ptr_global);
272
276
}
273
277
278
// xmm_t variant of the exp routine: issues the instruction sequence that
// computes an approximation of exp() of ymm_src into ymm_dst, following the
// "exp(x) = exp(g + n*log(2))" decomposition named in the comments below.
//
// Parameters:
//   ymm_dst  - receives the result; must be a different register than ymm_src
//              (asserted below).
//   ymm_src  - input; it is overwritten (clamped, then reduced) by this
//              routine, so the caller must not rely on it afterwards.
//   fx_idx, fy_idx, mask_idx, tmp_idx
//            - indices of four scratch registers. Nothing here verifies that
//              they differ from each other or from dst/src.
//
// Notes:
//   * Despite the ymm_ prefix, every local register handle here is an xmm_t.
//   * rax is saved with push and restored with pop around the body; in between
//     it holds the address of exp_float_consts and later exp_int_0x7f.
//   * Destination-first operand order is assumed for all calls -- confirm
//     against the assembler API in use.
+ void VActJitCode::exp_xmm (xmm_t & ymm_dst, xmm_t & ymm_src, int fx_idx,
279
+ int fy_idx, int mask_idx, int tmp_idx) {
280
+ assert (ymm_src.getIdx () != ymm_dst.getIdx ()); // TODO(TJ): use enforce
281
+ // TODO: also check that none of the scratch indices are equal to each other
282
+ xmm_t ymm_fx = xmm_t (fx_idx);
283
+ xmm_t ymm_fy = xmm_t (fy_idx);
284
+ xmm_t ymm_mask = xmm_t (mask_idx);
285
+ xmm_t ymm_tmp = xmm_t (tmp_idx);
286
+ reg64_t reg_ptr_global = rax;
287
+ push (reg_ptr_global);
288
+ mov (reg_ptr_global, reinterpret_cast <size_t >(exp_float_consts));
// Clamp the input between the constants stored at OFFSET_EXP_LOW and
// OFFSET_EXP_HIG (min against HIG first, then max against LOW).
289
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
290
+ vminps (ymm_src, ymm_src, ymm_tmp);
291
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
292
+ vmaxps (ymm_src, ymm_src, ymm_tmp);
293
+ // express exp(x) as exp(g + n*log(2))
294
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
295
+ vmulps (ymm_fx, ymm_src, ymm_tmp);
296
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
297
+ vaddps (ymm_fx, ymm_fx, ymm_tmp);
// NOTE(review): immediate 0x01 should select round-toward-negative-infinity
// per the ROUNDPS encoding -- confirm against the instruction reference.
298
+ vroundps (ymm_fy, ymm_fx, 0x01 );
299
+ // if greater, subtract 1
// The compare mask is AND-ed with the constant at offset 0 of the table
// (presumably 1.0f, given the comment above -- verify against
// exp_float_consts) and the result is subtracted from fy.
300
+ vcmpgtps (ymm_mask, ymm_fy, ymm_fx);
301
+ vmovaps (ymm_tmp, ptr[reg_ptr_global]);
302
+ vandps (ymm_mask, ymm_mask, ymm_tmp);
303
+ vsubps (ymm_fx, ymm_fy, ymm_mask);
304
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
305
+ vmulps (ymm_fy, ymm_fx, ymm_tmp);
306
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
// ymm_z is a second handle to the mask register (same index); the mask value
// is no longer needed past this point, so the register is reused.
307
+ xmm_t ymm_z = xmm_t (ymm_mask.getIdx ());
308
+ vmulps (ymm_z, ymm_fx, ymm_tmp);
// src -= fx*C1; src -= fx*C2  (src now holds the reduced argument)
309
+ vsubps (ymm_src, ymm_src, ymm_fy);
310
+ vsubps (ymm_src, ymm_src, ymm_z);
// z = src * src
311
+ vmulps (ymm_z, ymm_src, ymm_src);
// Polynomial in src with coefficients P0..P5, evaluated Horner-style:
// dst = P0*src, then repeatedly dst = (dst + Pk) * src.
312
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
313
+ vmulps (ymm_dst, ymm_src, ymm_tmp);
// NOTE(review): the stride is YMM_FLOAT_BLOCK * sizeof(float) even though the
// loads target xmm_t registers; this implies consecutive table entries are
// spaced at ymm width -- confirm against the layout of exp_float_consts.
314
+ for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
315
+ i += (YMM_FLOAT_BLOCK * sizeof (float ))) {
316
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
317
+ vaddps (ymm_dst, ymm_dst, ymm_tmp);
318
+ vmulps (ymm_dst, ymm_dst, ymm_src);
319
+ }
// dst = (dst + P5) * z + src + <constant at table offset 0>
320
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
321
+ vaddps (ymm_dst, ymm_dst, ymm_tmp);
322
+ vmulps (ymm_dst, ymm_dst, ymm_z);
323
+ vaddps (ymm_dst, ymm_dst, ymm_src);
324
+ vmovaps (ymm_tmp, ptr[reg_ptr_global]);
325
+ vaddps (ymm_dst, ymm_dst, ymm_tmp);
326
+ // build 2^n
// ymm_int is a copy of the ymm_fx handle, so the conversion below is in place
// on the same register.
327
+ xmm_t ymm_int = ymm_fx;
328
+ vcvttps2dq (ymm_int, ymm_fx);
// rax is repointed from the float table to exp_int_0x7f here.
329
+ mov (reg_ptr_global, reinterpret_cast <size_t >(exp_int_0x7f));
330
+ vmovdqa (ymm_tmp, ptr[reg_ptr_global]);
331
+ vpaddd (ymm_int, ymm_int, ymm_tmp);
// Shift left by 23 bits (presumably moving the biased integer into the
// single-precision exponent field -- verify), then scale the polynomial
// result by it.
332
+ vpslld (ymm_int, ymm_int, 23 );
333
+ vmulps (ymm_dst, ymm_dst, ymm_int);
334
+ pop (reg_ptr_global);
335
+ }
336
+
274
337
void VActJitCode::sigmoid_ymm (ymm_t & ymm_dst, ymm_t & ymm_src, int fx_idx,
275
338
int fy_idx, int mask_idx, int tmp_idx) {
276
339
// y = 1 / (1 + e^-x)
@@ -343,29 +406,58 @@ void VActJitCode::generate() {
343
406
vmovups (ptr[param2 + offset], ymm_dst);
344
407
offset += sizeof (float ) * YMM_FLOAT_BLOCK;
345
408
}
346
- if (type_ != operand_type::relu) {
409
+ if (type_ != operand_type::relu && type_ != operand_type::exp ) {
347
410
// TODO(TJ): remove me
348
411
ret ();
349
412
return ;
350
413
}
351
414
int rest = num_ % YMM_FLOAT_BLOCK;
352
415
if (rest >= 4 ) {
353
416
vmovups (xmm_src, ptr[param1 + offset]);
354
- vmaxps (xmm_dst, xmm_zero, xmm_src);
417
+ switch (type_) {
418
+ case operand_type::relu:
419
+ relu_xmm (xmm_dst, xmm_src, xmm_zero);
420
+ break ;
421
+ case operand_type::exp:
422
+ exp_xmm (xmm_dst, xmm_src, 2 , 3 , 4 , 5 );
423
+ break ;
424
+ default :
425
+ break ;
426
+ }
355
427
vmovups (ptr[param2 + offset], xmm_dst);
356
428
offset += sizeof (float ) * 4 ;
357
429
rest -= 4 ;
358
430
}
359
431
if (rest >= 2 ) {
360
- vmovups (xmm_src, ptr[param1 + offset]);
361
- vmaxps (xmm_dst, xmm_zero, xmm_src);
432
+ vmovq (xmm_src, ptr[param1 + offset]);
433
+ switch (type_) {
434
+ case operand_type::relu:
435
+ relu_xmm (xmm_dst, xmm_src, xmm_zero);
436
+ break ;
437
+ case operand_type::exp:
438
+ exp_xmm (xmm_dst, xmm_src, 2 , 3 , 4 , 5 );
439
+ break ;
440
+ default :
441
+ break ;
442
+ }
362
443
vmovq (ptr[param2 + offset], xmm_dst);
363
444
offset += sizeof (float ) * 2 ;
364
445
rest -= 2 ;
365
446
}
366
447
if (rest > 0 ) {
367
- vmovups (xmm_src, ptr[param1 + offset]);
368
- vmaxps (xmm_dst, xmm_zero, xmm_src);
448
+ // vmovups();
449
+ vmovss (xmm_src, ptr[param1 + offset]);
450
+
451
+ switch (type_) {
452
+ case operand_type::relu:
453
+ relu_xmm (xmm_dst, xmm_src, xmm_zero);
454
+ break ;
455
+ case operand_type::exp:
456
+ exp_xmm (xmm_dst, xmm_src, 2 , 3 , 4 , 5 );
457
+ break ;
458
+ default :
459
+ break ;
460
+ }
369
461
vmovss (ptr[param2 + offset], xmm_dst);
370
462
}
371
463
ret ();
0 commit comments