@@ -152,10 +152,6 @@ void ReluJitCode::generate() {
152
152
ret ();
153
153
}
154
154
155
- bool VExpJitCode::init (int d) {
156
- return MayIUse (avx) && d == 8 ; // only 8 yet
157
- }
158
-
159
155
// Appended to a variable definition to request 32-byte alignment
// (GCC/Clang attribute syntax). The constant tables in this file carry it and
// are read by the generated code with vmovaps, which needs aligned operands.
#define ALIGN32 __attribute__((aligned(32)))

// Bounds for the exp() input. EXP_HIG is the upper clamp applied to the source
// vector (vminps) before the polynomial evaluation; EXP_LOW is its negation.
// NOTE(review): EXP_LOW is presumably the matching lower clamp (vmaxps) - the
// use site is not visible in this chunk.
// The negative value is parenthesized so it stays a single operand wherever
// the macro is expanded.
#define EXP_HIG 88.3762626647949f
#define EXP_LOW (-88.3762626647949f)
@@ -171,6 +167,7 @@ bool VExpJitCode::init(int d) {
171
167
172
168
// Expands to eight comma-separated copies of `val`, i.e. one 8-lane group of
// an initializer list. There must be no space between the macro name and
// `(val)`, otherwise it becomes an object-like macro and stops taking an
// argument.
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
173
169
170
// Byte offsets into exp_float_consts, added to the table's base address when
// the generated code loads a constant. Slot k starts at
// k * AVX_FLOAT_BLOCK * sizeof(float) bytes.
// NOTE(review): the slot numbers line up with the table only if
// AVX_FLOAT_BLOCK equals the 8 copies REPEAT_8TIMES emits - confirm.
// Each replacement list is parenthesized so the macro evaluates as one operand
// even next to /, % or a cast.
#define OFFSET_EXP_ONE (0 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_0P5 (1 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_HIG (2 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_LOW (3 * AVX_FLOAT_BLOCK * sizeof(float))
@@ -183,24 +180,43 @@ bool VExpJitCode::init(int d) {
183
180
// Byte offsets of slots 10..15 of exp_float_consts (slot k starts at
// k * AVX_FLOAT_BLOCK * sizeof(float) bytes). Slots 13..15 hold the
// EXP_MAX_INPUT and sigmoid clamp thresholds appended at the end of the table.
// Each replacement list is parenthesized so the macro evaluates as one operand
// even next to /, % or a cast.
#define OFFSET_EXP_P3 (10 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P4 (11 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P5 (12 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_MAX_INPUT (13 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_SIGMOID_MAX (14 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_SIGMOID_MIN (15 * AVX_FLOAT_BLOCK * sizeof(float))
186
186
187
187
// Broadcast constants read by the generated exp/sigmoid code. Every entry is
// repeated 8 times (REPEAT_8TIMES) so one aligned 256-bit load yields a full
// vector of the same value. Entries are addressed by byte offset through the
// OFFSET_* macros, so the ORDER HERE MUST NOT CHANGE without updating them.
// NOTE(review): slot numbers below assume AVX_FLOAT_BLOCK == 8 - confirm.
static const float exp_float_consts[] ALIGN32 = {
    REPEAT_8TIMES(1.f),                    // slot 0  (OFFSET_EXP_ONE)
    REPEAT_8TIMES(0.5f),                   // slot 1  (OFFSET_EXP_0P5)
    REPEAT_8TIMES(EXP_HIG),                // slot 2  (OFFSET_EXP_HIG)
    REPEAT_8TIMES(EXP_LOW),                // slot 3  (OFFSET_EXP_LOW)
    REPEAT_8TIMES(CEPHES_LOG2EF),          // slot 4
    REPEAT_8TIMES(CEPHES_EXP_C1),          // slot 5
    REPEAT_8TIMES(CEPHES_EXP_C2),          // slot 6
    REPEAT_8TIMES(CEPHES_EXP_P0),          // slot 7
    REPEAT_8TIMES(CEPHES_EXP_P1),          // slot 8
    REPEAT_8TIMES(CEPHES_EXP_P2),          // slot 9
    REPEAT_8TIMES(CEPHES_EXP_P3),          // slot 10 (OFFSET_EXP_P3)
    REPEAT_8TIMES(CEPHES_EXP_P4),          // slot 11 (OFFSET_EXP_P4)
    REPEAT_8TIMES(CEPHES_EXP_P5),          // slot 12 (OFFSET_EXP_P5)
    REPEAT_8TIMES(EXP_MAX_INPUT),          // slot 13 (OFFSET_EXP_MAX_INPUT)
    REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),  // slot 14 (OFFSET_SIGMOID_MAX)
    REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)}; // slot 15 (OFFSET_SIGMOID_MIN)
195
204
196
205
// Eight copies of the integer 0x7f (127), 32-byte aligned so the generated
// code can load them as one vector.
// NOTE(review): presumably the float exponent bias added when building 2^n in
// the exp kernel - the use site is not visible in this chunk.
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
197
206
// 16-int, 32-byte-aligned, zero-initialized scratch buffer with file scope.
// NOTE(review): presumably the spill area the exp kernel addresses through
// reg_ptr_tmp. It is a single mutable static, so generated kernels running
// concurrently would share it - confirm against the use site.
static int g_tmp_mem[16] ALIGN32 = {0};
198
207
199
- void VExpJitCode::generate () {
200
- // in: ymm0, out: ymm1
201
- // use ymm 0~5, rax
202
- int offset = 0 ;
203
- vmovups (ymm_src, ptr[param1 + offset]);
208
+ bool VExpJitCode::init (int d) {
209
+ return MayIUse (avx) && d == 8 ; // only 8 yet
210
+ }
211
+
212
+ void VExpJitCode::exp_ymm (ymm_t & ymm_src, ymm_t & ymm_dst) {
213
+ // use reg rax and ymm 2~5
214
+ reg64_t reg_ptr_global = rax;
215
+ ymm_t ymm_fx = ymm_t (2 );
216
+ ymm_t ymm_fy = ymm_t (3 );
217
+ ymm_t ymm_mask = ymm_t (4 );
218
+ ymm_t ymm_tmp = ymm_t (5 );
219
+ push (reg_ptr_global);
204
220
mov (reg_ptr_global, reinterpret_cast <size_t >(exp_float_consts));
205
221
vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
206
222
vminps (ymm_src, ymm_src, ymm_tmp);
@@ -269,8 +285,45 @@ void VExpJitCode::generate() {
269
285
vmovdqa (ymm_int, ptr[reg_ptr_tmp]);
270
286
}
271
287
vmulps (ymm_dst, ymm_dst, ymm_int);
288
+ pop (reg_ptr_global);
289
+ }
290
+
291
+ void VExpJitCode::generate () {
292
+ int offset = 0 ;
293
+ vmovups (ymm_src, ptr[param1 + offset]);
294
+ exp_ymm (ymm_src, ymm_dst);
272
295
vmovups (ptr[param2 + offset], ymm_dst);
296
+ ret ();
297
+ }
298
+
299
+ bool VSigmoidJitCode::init (int d) {
300
+ return MayIUse (avx) && d == 8 ; // only 8 yet
301
+ }
273
302
303
+ void VSigmoidJitCode::sigmoid_ymm (ymm_t & ymm_src, ymm_t & ymm_dst) {
304
+ // use ymm2
305
+ reg64_t reg_ptr_global = rax;
306
+ ymm_t ymm_tmp = ymm_t (2 );
307
+ push (reg_ptr_global);
308
+ mov (reg_ptr_global, reinterpret_cast <size_t >(exp_float_consts));
309
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
310
+ vminps (ymm_src, ymm_src, ymm_tmp);
311
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
312
+ vmaxps (ymm_src, ymm_src, ymm_tmp);
313
+ vxorps (ymm_tmp, ymm_tmp, ymm_tmp);
314
+ vsubps (ymm_src, ymm_tmp, ymm_src);
315
+ exp_ymm (ymm_src, ymm_dst);
316
+ vmovaps (ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
317
+ vaddps (ymm_dst, ymm_dst, ymm_tmp);
318
+ vdivps (ymm_dst, ymm_tmp, ymm_dst);
319
+ pop (reg_ptr_global);
320
+ }
321
+
322
+ void VSigmoidJitCode::generate () {
323
+ int offset = 0 ;
324
+ vmovups (ymm_src, ptr[param1 + offset]);
325
+ sigmoid_ymm (ymm_src, ymm_dst);
326
+ vmovups (ptr[param2 + offset], ymm_dst);
274
327
ret ();
275
328
}
276
329
0 commit comments