@@ -60,60 +60,53 @@ void VXXJitCode::generate() {
60
60
offset += sizeof (float ) * YMM_FLOAT_BLOCK;
61
61
}
62
62
int rest = num_ % YMM_FLOAT_BLOCK;
63
- if (rest >= 4 ) {
64
- if (scalar_index_ != 1 ) {
65
- vmovups (xmm_src1, ptr[param1 + offset]);
66
- }
67
- if (scalar_index_ != 2 ) {
68
- vmovups (xmm_src2, ptr[param2 + offset]);
69
- }
70
- if (type_ == operand_type::mul) {
71
- vmulps (xmm_dst, xmm_src1, xmm_src2);
72
- } else if (type_ == operand_type::add) {
73
- vaddps (xmm_dst, xmm_src1, xmm_src2);
74
- }
75
- if (with_relu_) {
76
- vmaxps (xmm_dst, xmm_zero, xmm_dst);
77
- }
78
- vmovups (ptr[param3 + offset], xmm_dst);
79
- offset += sizeof (float ) * 4 ;
80
- rest -= 4 ;
81
- }
82
- if (rest >= 2 ) {
83
- if (scalar_index_ != 1 ) {
84
- vmovq (xmm_src1, ptr[param1 + offset]);
85
- }
86
- if (scalar_index_ != 2 ) {
87
- vmovq (xmm_src2, ptr[param2 + offset]);
63
+ int block = XMM_FLOAT_BLOCK;
64
+ while (rest > 0 ) {
65
+ if (rest >= 4 ) {
66
+ if (scalar_index_ != 1 ) {
67
+ vmovups (xmm_src1, ptr[param1 + offset]);
68
+ }
69
+ if (scalar_index_ != 2 ) {
70
+ vmovups (xmm_src2, ptr[param2 + offset]);
71
+ }
72
+ } else if (rest >= 2 ) {
73
+ if (scalar_index_ != 1 ) {
74
+ vmovq (xmm_src1, ptr[param1 + offset]);
75
+ }
76
+ if (scalar_index_ != 2 ) {
77
+ vmovq (xmm_src2, ptr[param2 + offset]);
78
+ }
79
+ } else {
80
+ if (scalar_index_ != 1 ) {
81
+ vmovss (xmm_src1, ptr[param1 + offset]);
82
+ }
83
+ if (scalar_index_ != 2 ) {
84
+ vmovss (xmm_src2, ptr[param2 + offset]);
85
+ }
88
86
}
89
- if (type_ == operand_type::mul) {
90
- vmulps (xmm_dst, xmm_src1, xmm_src2);
91
- } else if (type_ == operand_type::add) {
92
- vaddps (xmm_dst, xmm_src1, xmm_src2);
87
+ switch (type_) {
88
+ case operand_type::mul:
89
+ vmulps (xmm_dst, xmm_src1, xmm_src2);
90
+ break ;
91
+ case operand_type::add:
92
+ vaddps (xmm_dst, xmm_src1, xmm_src2);
93
+ break ;
94
+ default :
95
+ break ;
93
96
}
94
97
if (with_relu_) {
95
98
vmaxps (xmm_dst, xmm_zero, xmm_dst);
96
99
}
97
- vmovq (ptr[param3 + offset], xmm_dst);
98
- offset += sizeof (float ) * 2 ;
99
- rest -= 2 ;
100
- }
101
- if (rest > 0 ) {
102
- if (scalar_index_ != 1 ) {
103
- vmovss (xmm_src1, ptr[param1 + offset]);
104
- }
105
- if (scalar_index_ != 2 ) {
106
- vmovss (xmm_src2, ptr[param2 + offset]);
107
- }
108
- if (type_ == operand_type::mul) {
109
- vmulss (xmm_dst, xmm_src1, xmm_src2);
110
- } else if (type_ == operand_type::add) {
111
- vaddss (xmm_dst, xmm_src1, xmm_src2);
100
+ if (rest >= 4 ) {
101
+ vmovups (ptr[param3 + offset], xmm_dst);
102
+ } else if (rest >= 2 ) {
103
+ vmovq (ptr[param3 + offset], xmm_dst);
104
+ } else {
105
+ vmovss (ptr[param3 + offset], xmm_dst);
112
106
}
113
- if (with_relu_) {
114
- vmaxps (xmm_dst, xmm_zero, xmm_dst);
115
- }
116
- vmovss (ptr[param3 + offset], xmm_dst);
107
+ offset += sizeof (float ) * block;
108
+ rest -= block;
109
+ block /= 2 ;
117
110
}
118
111
ret ();
119
112
}
@@ -175,11 +168,9 @@ static int g_tmp_mem[16] ALIGN32 = {0};
175
168
176
169
bool VActJitCode::init (int d, operand_type type) {
177
170
bool ok = MayIUse (avx);
178
- if (type == operand_type::relu) {
171
+ if (type == operand_type::relu || type == operand_type::exp) {
172
+ // TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
179
173
return ok;
180
- } else if (type == operand_type::exp) {
181
- // exp is slower than mkl when d >= 256
182
- return ok; // && d % 4 == 0 && d < 256;
183
174
} else {
184
175
// TODO(TJ): support more
185
176
return ok && d % 8 == 0 ;
@@ -412,24 +403,15 @@ void VActJitCode::generate() {
412
403
return ;
413
404
}
414
405
int rest = num_ % YMM_FLOAT_BLOCK;
415
- if (rest >= 4 ) {
416
- vmovups (xmm_src, ptr[param1 + offset]);
417
- switch (type_) {
418
- case operand_type::relu:
419
- relu_xmm (xmm_dst, xmm_src, xmm_zero);
420
- break ;
421
- case operand_type::exp:
422
- exp_xmm (xmm_dst, xmm_src, 2 , 3 , 4 , 5 );
423
- break ;
424
- default :
425
- break ;
406
+ int block = XMM_FLOAT_BLOCK;
407
+ while (rest > 0 ) {
408
+ if (rest >= 4 ) {
409
+ vmovups (xmm_src, ptr[param1 + offset]);
410
+ } else if (rest >= 2 ) {
411
+ vmovq (xmm_src, ptr[param1 + offset]);
412
+ } else {
413
+ vmovss (xmm_src, ptr[param1 + offset]);
426
414
}
427
- vmovups (ptr[param2 + offset], xmm_dst);
428
- offset += sizeof (float ) * 4 ;
429
- rest -= 4 ;
430
- }
431
- if (rest >= 2 ) {
432
- vmovq (xmm_src, ptr[param1 + offset]);
433
415
switch (type_) {
434
416
case operand_type::relu:
435
417
relu_xmm (xmm_dst, xmm_src, xmm_zero);
@@ -440,25 +422,16 @@ void VActJitCode::generate() {
440
422
default :
441
423
break ;
442
424
}
443
- vmovq (ptr[param2 + offset], xmm_dst);
444
- offset += sizeof (float ) * 2 ;
445
- rest -= 2 ;
446
- }
447
- if (rest > 0 ) {
448
- // vmovups();
449
- vmovss (xmm_src, ptr[param1 + offset]);
450
-
451
- switch (type_) {
452
- case operand_type::relu:
453
- relu_xmm (xmm_dst, xmm_src, xmm_zero);
454
- break ;
455
- case operand_type::exp:
456
- exp_xmm (xmm_dst, xmm_src, 2 , 3 , 4 , 5 );
457
- break ;
458
- default :
459
- break ;
425
+ if (rest >= 4 ) {
426
+ vmovups (ptr[param2 + offset], xmm_dst);
427
+ } else if (rest >= 2 ) {
428
+ vmovq (ptr[param2 + offset], xmm_dst);
429
+ } else {
430
+ vmovss (ptr[param2 + offset], xmm_dst);
460
431
}
461
- vmovss (ptr[param2 + offset], xmm_dst);
432
+ offset += sizeof (float ) * block;
433
+ rest -= block;
434
+ block /= 2 ;
462
435
}
463
436
ret ();
464
437
}
0 commit comments