@@ -151,6 +151,132 @@ void ReluJitCode::generate() {
151
151
}
152
152
ret ();
153
153
}
154
+
155
+ bool VExpJitCode::init (int d) {
156
+ return MayIUse (avx) && d == 8 ; // only 8 yet
157
+ }
158
+
159
// 32-byte alignment attribute so the constant tables below can be read with
// aligned AVX loads (vmovaps/vmovdqa).
#define ALIGN32 __attribute__((aligned(32)))

// Input clamp bounds for exp(): outside [EXP_LOW, EXP_HIG] single-precision
// exp() underflows to 0 / overflows to +inf anyway.
// All macro bodies are parenthesized so they expand safely inside larger
// expressions (negative literals, products next to '/' or '%', etc.).
#define EXP_HIG (88.3762626647949f)
#define EXP_LOW (-88.3762626647949f)
// Cephes-style constants: log2(e); a two-part split of ln(2)
// (C1 + C2 ~= ln(2), split for extra precision); and the polynomial
// coefficients P0..P5 of the exp(g) approximation.
#define CEPHES_LOG2EF (1.44269504088896341)
#define CEPHES_EXP_C1 (0.693359375)
#define CEPHES_EXP_C2 (-2.12194440e-4)
#define CEPHES_EXP_P0 (1.9875691500E-4)
#define CEPHES_EXP_P1 (1.3981999507E-3)
#define CEPHES_EXP_P2 (8.3334519073E-3)
#define CEPHES_EXP_P3 (4.1665795894E-2)
#define CEPHES_EXP_P4 (1.6666665459E-1)
#define CEPHES_EXP_P5 (5.0000001201E-1)

// Expand one value into a full 8-lane row (one 256-bit AVX register of floats).
#define REPEAT_8TIMES(val) (val), (val), (val), (val), (val), (val), (val), (val)

// Byte offset of each 8-float row inside exp_float_consts (row 0 holds 1.0f).
// Parenthesized so the whole product stays one operand at the use site.
#define OFFSET_EXP_0P5 (1 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_HIG (2 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_LOW (3 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_LOG2EF (4 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_C1 (5 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_C2 (6 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P0 (7 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P1 (8 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P2 (9 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P3 (10 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P4 (11 * AVX_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P5 (12 * AVX_FLOAT_BLOCK * sizeof(float))
186
+
187
// Read-only constant table for the exp() approximation.  Each constant is
// replicated 8 times so one row fills a whole 256-bit AVX register; rows are
// addressed via the OFFSET_EXP_* byte offsets (row 0 = 1.0f, row 1 = 0.5f,
// row 2 = EXP_HIG, ..., row 12 = CEPHES_EXP_P5).  The row order here and the
// OFFSET_EXP_* macros must be kept in sync.
static const float exp_float_consts[] ALIGN32 = {
    REPEAT_8TIMES(1.f),           REPEAT_8TIMES(0.5f),
    REPEAT_8TIMES(EXP_HIG),       REPEAT_8TIMES(EXP_LOW),
    REPEAT_8TIMES(CEPHES_LOG2EF), REPEAT_8TIMES(CEPHES_EXP_C1),
    REPEAT_8TIMES(CEPHES_EXP_C2), REPEAT_8TIMES(CEPHES_EXP_P0),
    REPEAT_8TIMES(CEPHES_EXP_P1), REPEAT_8TIMES(CEPHES_EXP_P2),
    REPEAT_8TIMES(CEPHES_EXP_P3), REPEAT_8TIMES(CEPHES_EXP_P4),
    REPEAT_8TIMES(CEPHES_EXP_P5)};

// 0x7f in every lane: the IEEE-754 single-precision exponent bias; generate()
// adds it to the integer n before shifting into the exponent field (<< 23)
// to construct 2^n.
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};

// Mutable scratch buffer addressed by the emitted AVX-only (non-AVX2) code
// path, used to move ymm register halves through memory.
// NOTE(review): file-scope mutable state shared by every invocation of the
// generated kernel — looks racy if kernels run on multiple threads at once;
// confirm execution is single-threaded.
static int g_tmp_mem[16] ALIGN32 = {0};
198
+
199
// Emit AVX instructions computing y = exp(x) for one block of 8 floats:
// load from [param1], store to [param2].  Uses the Cephes-style
// decomposition exp(x) = exp(g) * 2^n with n = round(x * log2(e)), a
// polynomial approximation of exp(g) (coefficients P0..P5), and integer
// construction of 2^n via the IEEE-754 exponent field.
// preCode()/postCode() presumably emit the prologue/epilogue — defined
// elsewhere; confirm register save/restore there.
void VExpJitCode::generate() {
  preCode();
  // push some?
  // in: ymm0, out: ymm1
  // use ymm 0~5 (and ymm 14~15 if avx only)
  int offset = 0;
  vmovups(ymm_src, ptr[param1 + offset]);  // load x (unaligned user buffer)
  mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
  // clamp x into [EXP_LOW, EXP_HIG]
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
  vminps(ymm_src, ymm_src, ymm_tmp);
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
  vmaxps(ymm_src, ymm_src, ymm_tmp);
  // express exp(x) as exp(g + n*log(2))
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
  vmulps(ymm_fx, ymm_src, ymm_tmp);  // fx = x * log2(e)
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
  vaddps(ymm_fx, ymm_fx, ymm_tmp);   // fx += 0.5
  vroundps(ymm_fy, ymm_fx, 0x01);    // fy = floor(fx)
  // if greater, subtract 1
  vcmpgtps(ymm_mask, ymm_fy, ymm_fx);
  vmovaps(ymm_tmp, ptr[reg_ptr_global]);  // row 0 of the table: 1.0f
  vandps(ymm_mask, ymm_mask, ymm_tmp);
  vsubps(ymm_fx, ymm_fy, ymm_mask);  // fx = n (still as float)
  // g = x - n*ln(2), with ln(2) split into C1 + C2 for precision
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
  vmulps(ymm_fy, ymm_fx, ymm_tmp);
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
  vmulps(ymm_z, ymm_fx, ymm_tmp);  // ymm_z use same with mask
  vsubps(ymm_src, ymm_src, ymm_fy);
  vsubps(ymm_src, ymm_src, ymm_z);
  vmulps(ymm_z, ymm_src, ymm_src);  // z = g^2
  // Horner-style evaluation of the polynomial in g using P0..P5;
  // the pieces are combined below into exp(g).
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
  vmulps(ymm_dst, ymm_src, ymm_tmp);
  for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
       i += (AVX_FLOAT_BLOCK * sizeof(float))) {
    vmovaps(ymm_tmp, ptr[reg_ptr_global + i]);  // P1~P4
    vaddps(ymm_dst, ymm_dst, ymm_tmp);
    vmulps(ymm_dst, ymm_dst, ymm_src);
  }
  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
  vaddps(ymm_dst, ymm_dst, ymm_tmp);
  vmulps(ymm_dst, ymm_dst, ymm_z);    // * g^2
  vaddps(ymm_dst, ymm_dst, ymm_src);  // + g
  vmovaps(ymm_tmp, ptr[reg_ptr_global]);
  vaddps(ymm_dst, ymm_dst, ymm_tmp);  // + 1.0  => exp(g)

  // build 2^n
  ymm_t ymm_int = ymm_fx;
  vcvttps2dq(ymm_int, ymm_fx);  // n as int32 lanes
  mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
  vmovdqa(ymm_tmp, ptr[reg_ptr_global]);
  if (MayIUse(avx2)) {
    // 2^n as a float bit pattern: (n + 127) << 23
    vpaddd(ymm_int, ymm_int, ymm_tmp);
    vpslld(ymm_int, ymm_int, 23);
  } else if (MayIUse(avx)) {
    // AVX1 lacks 256-bit integer ops: spill both ymm registers to scratch
    // memory and process the two 128-bit halves with xmm instructions.
    // use ymm_int, ymm_tmp and reg_ptr_global
    xmm_t xtmp1 = xmm_t(ymm_int);  // or magic number should equal the ymm_int
    xmm_t xtmp2 = xmm_t(ymm_tmp);  // or magic number should equal the ymm_tmp
    mov(reg_ptr_global, reinterpret_cast<size_t>(g_tmp_mem));
    vmovdqa(ptr[reg_ptr_global], ymm_int);
    vmovdqa(ptr[reg_ptr_global + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
    // low 128 bits
    vpaddd(xtmp1, xtmp1, xtmp2);
    vpslld(xtmp1, xtmp1, 23);
    vmovdqa(ptr[reg_ptr_global], xtmp1);
    // next 128bits
    vmovdqa(xtmp1, ptr[reg_ptr_global + 4 /* xmm float block*/ * sizeof(float)]);
    vmovdqa(xtmp2,
            ptr[reg_ptr_global +
                (AVX_FLOAT_BLOCK + 4 /* xmm float block*/) * sizeof(float)]);
    vpaddd(xtmp1, xtmp1, xtmp2);
    vpslld(xtmp1, xtmp1, 23);
    vmovdqa(ptr[reg_ptr_global + 4 /* xmm float block*/ * sizeof(float)], xtmp1);
    // load out
    vmovdqa(ymm_int, ptr[reg_ptr_global]);
  }
  vmulps(ymm_dst, ymm_dst, ymm_int);       // exp(x) = exp(g) * 2^n
  vmovups(ptr[param2 + offset], ymm_dst);  // store result (unaligned)

  // ret();
  postCode();
}
279
+
154
280
} // namespace gen
155
281
} // namespace jitkernel
156
282
} // namespace math
0 commit comments