@@ -41,7 +41,7 @@ void VXXJitCode::generate() {
41
41
} else if (scalar_index_ == 2 ) {
42
42
vbroadcastss (ymm_src2, ptr[param2]);
43
43
}
44
- for (int i = 0 ; i < num_ / AVX_FLOAT_BLOCK ; ++i) {
44
+ for (int i = 0 ; i < num_ / YMM_FLOAT_BLOCK ; ++i) {
45
45
if (scalar_index_ != 1 ) {
46
46
vmovups (ymm_src1, ptr[param1 + offset]);
47
47
}
@@ -57,9 +57,9 @@ void VXXJitCode::generate() {
57
57
vmaxps (ymm_dst, ymm_zero, ymm_dst);
58
58
}
59
59
vmovups (ptr[param3 + offset], ymm_dst);
60
- offset += sizeof (float ) * AVX_FLOAT_BLOCK ;
60
+ offset += sizeof (float ) * YMM_FLOAT_BLOCK ;
61
61
}
62
- int rest = num_ % AVX_FLOAT_BLOCK ;
62
+ int rest = num_ % YMM_FLOAT_BLOCK ;
63
63
if (rest >= 4 ) {
64
64
if (scalar_index_ != 1 ) {
65
65
vmovups (xmm_src1, ptr[param1 + offset]);
@@ -118,18 +118,237 @@ void VXXJitCode::generate() {
118
118
ret ();
119
119
}
120
120
121
- bool ReluJitCode::init (int d) { return MayIUse (avx); }
121
+ #define ALIGN32 __attribute__ ((aligned(32 )))
122
+ #define EXP_HIG 88 .3762626647949f
123
+ #define EXP_LOW -88 .3762626647949f
124
+ #define CEPHES_LOG2EF 1.44269504088896341
125
+ #define CEPHES_EXP_C1 0.693359375
126
+ #define CEPHES_EXP_C2 -2.12194440e-4
127
+ #define CEPHES_EXP_P0 1.9875691500E-4
128
+ #define CEPHES_EXP_P1 1.3981999507E-3
129
+ #define CEPHES_EXP_P2 8.3334519073E-3
130
+ #define CEPHES_EXP_P3 4.1665795894E-2
131
+ #define CEPHES_EXP_P4 1.6666665459E-1
132
+ #define CEPHES_EXP_P5 5.0000001201E-1
122
133
123
- void ReluJitCode::generate () {
124
- int offset = 0 ;
134
+ #define REPEAT_8TIMES (val ) val, val, val, val, val, val, val, val
135
+
136
+ #define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof (float )
137
+ #define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof (float )
138
+ #define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof (float )
139
+ #define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof (float )
140
+ #define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof (float )
141
+ #define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof (float )
142
+ #define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof (float )
143
+ #define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof (float )
144
+ #define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof (float )
145
+ #define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof (float )
146
+ #define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof (float )
147
+ #define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof (float )
148
+ #define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof (float )
149
+ #define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof (float )
150
+ #define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof (float )
151
+ #define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof (float )
152
+ #define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof (float )
153
+
154
+ static const float exp_float_consts[] ALIGN32 = {
155
+ REPEAT_8TIMES (1 .f ),
156
+ REPEAT_8TIMES (2 .f ),
157
+ REPEAT_8TIMES (0 .5f ),
158
+ REPEAT_8TIMES (EXP_HIG),
159
+ REPEAT_8TIMES (EXP_LOW),
160
+ REPEAT_8TIMES (CEPHES_LOG2EF),
161
+ REPEAT_8TIMES (CEPHES_EXP_C1),
162
+ REPEAT_8TIMES (CEPHES_EXP_C2),
163
+ REPEAT_8TIMES (CEPHES_EXP_P0),
164
+ REPEAT_8TIMES (CEPHES_EXP_P1),
165
+ REPEAT_8TIMES (CEPHES_EXP_P2),
166
+ REPEAT_8TIMES (CEPHES_EXP_P3),
167
+ REPEAT_8TIMES (CEPHES_EXP_P4),
168
+ REPEAT_8TIMES (CEPHES_EXP_P5),
169
+ REPEAT_8TIMES (EXP_MAX_INPUT),
170
+ REPEAT_8TIMES (SIGMOID_THRESHOLD_MAX),
171
+ REPEAT_8TIMES (SIGMOID_THRESHOLD_MIN)};
172
+
173
+ static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES (0x7f )};
174
+ static int g_tmp_mem[16 ] ALIGN32 = {0 };
175
+
176
+ bool VActJitCode::init (int d, operand_type type) {
177
+ bool ok = MayIUse (avx);
178
+ if (type == operand_type::relu) {
179
+ return ok;
180
+ } else if (type == operand_type::exp) {
181
+ // exp is slower than mkl when d >= 256
182
+ return ok && d % 8 == 0 && d < 256 ;
183
+ } else {
184
+ // TODO(TJ): support more
185
+ return ok && d % 8 == 0 ;
186
+ }
187
+ }
188
+
189
+ void VActJitCode::relu_ymm (ymm_t & ymm_dst, ymm_t & ymm_src, ymm_t & ymm_zero) {
190
+ vmaxps (ymm_dst, ymm_zero, ymm_src);
191
+ }
192
+
193
+ void VActJitCode::exp_ymm (ymm_t & ymm_dst, ymm_t & ymm_src, int fx_idx,
194
+ int fy_idx, int mask_idx, int tmp_idx) {
195
+ assert (ymm_src.getIdx () != ymm_dst.getIdx ()); // TODO(TJ): use enfore
196
+ // check all idx can not equal
197
+ ymm_t ymm_fx = ymm_t (fx_idx);
198
+ ymm_t ymm_fy = ymm_t (fy_idx);
199
+ ymm_t ymm_mask = ymm_t (mask_idx);
200
+ ymm_t ymm_tmp = ymm_t (tmp_idx);
201
+ reg64_t reg_ptr_global = rax;
202
+ push (reg_ptr_global);
203
+ mov (reg_ptr_global, reinterpret_cast <size_t >(exp_float_consts));
204
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
205
+ vminps (ymm_src, ymm_src, ymm_tmp);
206
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
207
+ vmaxps (ymm_src, ymm_src, ymm_tmp);
208
+ // express exp(x) as exp(g + n*log(2))
209
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
210
+ vmulps (ymm_fx, ymm_src, ymm_tmp);
211
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
212
+ vaddps (ymm_fx, ymm_fx, ymm_tmp);
213
+ vroundps (ymm_fy, ymm_fx, 0x01 );
214
+ // if greater, substract 1
215
+ vcmpgtps (ymm_mask, ymm_fy, ymm_fx);
216
+ vmovaps (ymm_tmp, ptr[reg_ptr_global]);
217
+ vandps (ymm_mask, ymm_mask, ymm_tmp);
218
+ vsubps (ymm_fx, ymm_fy, ymm_mask);
219
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
220
+ vmulps (ymm_fy, ymm_fx, ymm_tmp);
221
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
222
+ ymm_t ymm_z = ymm_t (ymm_mask.getIdx ());
223
+ vmulps (ymm_z, ymm_fx, ymm_tmp);
224
+ vsubps (ymm_src, ymm_src, ymm_fy);
225
+ vsubps (ymm_src, ymm_src, ymm_z);
226
+ vmulps (ymm_z, ymm_src, ymm_src);
227
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
228
+ vmulps (ymm_dst, ymm_src, ymm_tmp);
229
+ for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
230
+ i += (YMM_FLOAT_BLOCK * sizeof (float ))) {
231
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
232
+ vaddps (ymm_dst, ymm_dst, ymm_tmp);
233
+ vmulps (ymm_dst, ymm_dst, ymm_src);
234
+ }
235
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
236
+ vaddps (ymm_dst, ymm_dst, ymm_tmp);
237
+ vmulps (ymm_dst, ymm_dst, ymm_z);
238
+ vaddps (ymm_dst, ymm_dst, ymm_src);
239
+ vmovaps (ymm_tmp, ptr[reg_ptr_global]);
240
+ vaddps (ymm_dst, ymm_dst, ymm_tmp);
241
+ // build 2^n
242
+ ymm_t ymm_int = ymm_fx;
243
+ vcvttps2dq (ymm_int, ymm_fx);
244
+ mov (reg_ptr_global, reinterpret_cast <size_t >(exp_int_0x7f));
245
+ vmovdqa (ymm_tmp, ptr[reg_ptr_global]);
246
+ if (MayIUse (avx2)) {
247
+ vpaddd (ymm_int, ymm_int, ymm_tmp);
248
+ vpslld (ymm_int, ymm_int, 23 );
249
+ } else if (MayIUse (avx)) {
250
+ xmm_t xtmp1 = xmm_t (ymm_int.getIdx ());
251
+ xmm_t xtmp2 = xmm_t (ymm_tmp.getIdx ());
252
+ reg64_t reg_ptr_tmp = reg_ptr_global;
253
+ mov (reg_ptr_tmp, reinterpret_cast <size_t >(g_tmp_mem));
254
+ vmovdqa (ptr[reg_ptr_tmp], ymm_int);
255
+ vmovdqa (ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof (float )], ymm_tmp);
256
+ vpaddd (xtmp1, xtmp1, xtmp2);
257
+ vpslld (xtmp1, xtmp1, 23 );
258
+ vmovdqa (ptr[reg_ptr_tmp], xtmp1);
259
+ // next 128bits
260
+ vmovdqa (xtmp1, ptr[reg_ptr_tmp + 4 /* xmm float block*/ * sizeof (float )]);
261
+ vmovdqa (xtmp2,
262
+ ptr[reg_ptr_tmp +
263
+ (YMM_FLOAT_BLOCK + 4 /* xmm float block*/ ) * sizeof (float )]);
264
+ vpaddd (xtmp1, xtmp1, xtmp2);
265
+ vpslld (xtmp1, xtmp1, 23 );
266
+ vmovdqa (ptr[reg_ptr_tmp + 4 /* xmm float block*/ * sizeof (float )], xtmp1);
267
+ // load out
268
+ vmovdqa (ymm_int, ptr[reg_ptr_tmp]);
269
+ }
270
+ vmulps (ymm_dst, ymm_dst, ymm_int);
271
+ pop (reg_ptr_global);
272
+ }
273
+
274
+ void VActJitCode::sigmoid_ymm (ymm_t & ymm_dst, ymm_t & ymm_src, int fx_idx,
275
+ int fy_idx, int mask_idx, int tmp_idx) {
276
+ // y = 1 / (1 + e^-x)
277
+ ymm_t ymm_tmp = ymm_t (tmp_idx);
278
+ reg64_t reg_ptr_global = rax;
279
+ push (reg_ptr_global);
280
+ mov (reg_ptr_global, reinterpret_cast <size_t >(exp_float_consts));
281
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
282
+ vminps (ymm_src, ymm_src, ymm_tmp);
283
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
284
+ vmaxps (ymm_src, ymm_src, ymm_tmp);
285
+ vxorps (ymm_tmp, ymm_tmp, ymm_tmp);
286
+ vsubps (ymm_src, ymm_tmp, ymm_src);
287
+ exp_ymm (ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
288
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
289
+ vaddps (ymm_dst, ymm_dst, ymm_tmp);
290
+ vdivps (ymm_dst, ymm_tmp, ymm_dst);
291
+ pop (reg_ptr_global);
292
+ }
293
+
294
+ void VActJitCode::tanh_ymm (ymm_t & ymm_dst, ymm_t & ymm_src, int fx_idx,
295
+ int fy_idx, int mask_idx, int tmp_idx) {
296
+ // y = 2 / (1 + e^(-2x)) - 1
297
+ ymm_t ymm_tmp = ymm_t (tmp_idx);
298
+ ymm_t ymm_zero = ymm_t (mask_idx);
299
+ reg64_t reg_ptr_global = rax;
300
+ push (reg_ptr_global);
301
+ mov (reg_ptr_global, reinterpret_cast <size_t >(exp_float_consts));
302
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
125
303
vxorps (ymm_zero, ymm_zero, ymm_zero);
126
- for (int i = 0 ; i < num_ / AVX_FLOAT_BLOCK; ++i) {
304
+ vsubps (ymm_tmp, ymm_zero, ymm_tmp);
305
+ vmulps (ymm_src, ymm_src, ymm_tmp);
306
+ exp_ymm (ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
307
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
308
+ vaddps (ymm_dst, ymm_dst, ymm_tmp);
309
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
310
+ vdivps (ymm_dst, ymm_tmp, ymm_dst);
311
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
312
+ vsubps (ymm_dst, ymm_dst, ymm_tmp);
313
+ pop (reg_ptr_global);
314
+ }
315
+
316
+ void VActJitCode::generate () {
317
+ xmm_t xmm_zero = xmm_t (2 );
318
+ ymm_t ymm_zero = ymm_t (2 );
319
+ if (type_ == operand_type::relu) {
320
+ vxorps (ymm_zero, ymm_zero, ymm_zero);
321
+ }
322
+ int offset = 0 ;
323
+ for (int i = 0 ; i < num_ / YMM_FLOAT_BLOCK; ++i) {
127
324
vmovups (ymm_src, ptr[param1 + offset]);
128
- vmaxps (ymm_dst, ymm_zero, ymm_src);
325
+ switch (type_) {
326
+ case operand_type::relu:
327
+ relu_ymm (ymm_dst, ymm_src, ymm_zero);
328
+ break ;
329
+ case operand_type::exp:
330
+ exp_ymm (ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
331
+ break ;
332
+ case operand_type::sigmoid:
333
+ sigmoid_ymm (ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
334
+ break ;
335
+ case operand_type::tanh:
336
+ tanh_ymm (ymm_dst, ymm_src, 2 , 3 , 4 , 5 );
337
+ break ;
338
+ case operand_type::identity:
339
+ break ;
340
+ default :
341
+ break ;
342
+ }
129
343
vmovups (ptr[param2 + offset], ymm_dst);
130
- offset += sizeof (float ) * AVX_FLOAT_BLOCK;
344
+ offset += sizeof (float ) * YMM_FLOAT_BLOCK;
345
+ }
346
+ if (type_ != operand_type::relu) {
347
+ // TODO(TJ): remove me
348
+ ret ();
349
+ return ;
131
350
}
132
- int rest = num_ % AVX_FLOAT_BLOCK ;
351
+ int rest = num_ % YMM_FLOAT_BLOCK ;
133
352
if (rest >= 4 ) {
134
353
vmovups (xmm_src, ptr[param1 + offset]);
135
354
vmaxps (xmm_dst, xmm_zero, xmm_src);
@@ -151,6 +370,7 @@ void ReluJitCode::generate() {
151
370
}
152
371
ret ();
153
372
}
373
+
154
374
} // namespace gen
155
375
} // namespace jitkernel
156
376
} // namespace math
0 commit comments