@@ -141,50 +141,52 @@ typedef union imm_xmm_union {
141
141
AVX2_BITOP_USING_SSE2 (slli_epi32);
142
142
AVX2_INTOP_USING_SSE2 (add_epi32);
143
143
144
/* Shared body of the AVX exp() implementations (ExpAVX / ExpAVX2).
 *
 * Cephes-style approximation: clamp x to [exp_lo, exp_hi], split
 * x = g + n*log(2), then evaluate a degree-5 polynomial in g.
 * On exit the following locals are live for the caller to finish with:
 *   y    - polynomial result (~exp(g), still missing the 2^n factor)
 *   imm0 - n truncated to 32-bit integers, ready to build 2^n from
 *   one  - all-ones (1.0f) constant, reused by the callers
 * The last statement deliberately has no trailing semicolon so call
 * sites read naturally as "AVXEXP_BASE;". */
#define AVXEXP_BASE \
  __m256 tmp = _mm256_setzero_ps(), fx; \
  __m256 one = *reinterpret_cast <const __m256*>(_ps256_one); \
  __m256i imm0; \
  x = _mm256_min_ps(x, *reinterpret_cast <const __m256*>(_ps256_exp_hi)); \
  x = _mm256_max_ps(x, *reinterpret_cast <const __m256*>(_ps256_exp_lo)); \
  /* express exp(x) as exp(g + n*log(2)) */ \
  fx = _mm256_mul_ps(x, \
      *reinterpret_cast <const __m256*>(_ps256_cephes_LOG2EF)); \
  fx = _mm256_add_ps(fx, *reinterpret_cast <const __m256*>(_ps256_0p5)); \
  tmp = _mm256_floor_ps(fx); \
  /* if greater, subtract 1 */ \
  __m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); \
  mask = _mm256_and_ps(mask, one); \
  fx = _mm256_sub_ps(tmp, mask); \
  tmp = _mm256_mul_ps(fx, \
      *reinterpret_cast <const __m256*>(_ps256_cephes_exp_C1)); \
  __m256 z = _mm256_mul_ps( \
      fx, *reinterpret_cast <const __m256*>(_ps256_cephes_exp_C2)); \
  x = _mm256_sub_ps(x, tmp); \
  x = _mm256_sub_ps(x, z); \
  z = _mm256_mul_ps(x, x); \
  /* Horner evaluation of the degree-5 polynomial in x */ \
  __m256 y = *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p0); \
  y = _mm256_mul_ps(y, x); \
  y = _mm256_add_ps(y, \
      *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p1)); \
  y = _mm256_mul_ps(y, x); \
  y = _mm256_add_ps(y, \
      *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p2)); \
  y = _mm256_mul_ps(y, x); \
  y = _mm256_add_ps(y, \
      *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p3)); \
  y = _mm256_mul_ps(y, x); \
  y = _mm256_add_ps(y, \
      *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p4)); \
  y = _mm256_mul_ps(y, x); \
  y = _mm256_add_ps(y, \
      *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p5)); \
  y = _mm256_mul_ps(y, z); \
  y = _mm256_add_ps(y, x); \
  y = _mm256_add_ps(y, one); \
  /* build 2^n */ \
  imm0 = _mm256_cvttps_epi32(fx)
187
+
144
188
__m256 ExpAVX (__m256 x) {
145
- __m256 tmp = _mm256_setzero_ps (), fx;
146
- __m256 one = *reinterpret_cast <const __m256*>(_ps256_one);
147
- __m256i imm0;
148
-
149
- x = _mm256_min_ps (x, *reinterpret_cast <const __m256*>(_ps256_exp_hi));
150
- x = _mm256_max_ps (x, *reinterpret_cast <const __m256*>(_ps256_exp_lo));
151
-
152
- /* express exp(x) as exp(g + n*log(2)) */
153
- fx = _mm256_mul_ps (x, *reinterpret_cast <const __m256*>(_ps256_cephes_LOG2EF));
154
- fx = _mm256_add_ps (fx, *reinterpret_cast <const __m256*>(_ps256_0p5));
155
-
156
- tmp = _mm256_floor_ps (fx);
157
-
158
- /* if greater, substract 1 */
159
- __m256 mask = _mm256_cmp_ps (tmp, fx, _CMP_GT_OS);
160
- mask = _mm256_and_ps (mask, one);
161
- fx = _mm256_sub_ps (tmp, mask);
162
-
163
- tmp =
164
- _mm256_mul_ps (fx, *reinterpret_cast <const __m256*>(_ps256_cephes_exp_C1));
165
- __m256 z =
166
- _mm256_mul_ps (fx, *reinterpret_cast <const __m256*>(_ps256_cephes_exp_C2));
167
- x = _mm256_sub_ps (x, tmp);
168
- x = _mm256_sub_ps (x, z);
169
- z = _mm256_mul_ps (x, x);
170
-
171
- __m256 y = *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p0);
172
- y = _mm256_mul_ps (y, x);
173
- y = _mm256_add_ps (y, *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p1));
174
- y = _mm256_mul_ps (y, x);
175
- y = _mm256_add_ps (y, *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p2));
176
- y = _mm256_mul_ps (y, x);
177
- y = _mm256_add_ps (y, *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p3));
178
- y = _mm256_mul_ps (y, x);
179
- y = _mm256_add_ps (y, *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p4));
180
- y = _mm256_mul_ps (y, x);
181
- y = _mm256_add_ps (y, *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p5));
182
- y = _mm256_mul_ps (y, z);
183
- y = _mm256_add_ps (y, x);
184
- y = _mm256_add_ps (y, one);
185
-
186
- /* build 2^n */
187
- imm0 = _mm256_cvttps_epi32 (fx);
189
+ AVXEXP_BASE;
188
190
// two AVX2 instructions using SSE2
189
191
imm0 = avx2_mm256_add_epi32 (imm0,
190
192
*reinterpret_cast <const __m256i*>(_pi256_0x7f));
@@ -197,48 +199,7 @@ __m256 ExpAVX(__m256 x) {
197
199
198
200
#ifdef __AVX2__
199
201
__m256 ExpAVX2 (__m256 x) {
200
- __m256 tmp = _mm256_setzero_ps (), fx;
201
- __m256 one = *reinterpret_cast <const __m256*>(_ps256_one);
202
- __m256i imm0;
203
-
204
- x = _mm256_min_ps (x, *reinterpret_cast <const __m256*>(_ps256_exp_hi));
205
- x = _mm256_max_ps (x, *reinterpret_cast <const __m256*>(_ps256_exp_lo));
206
-
207
- /* express exp(x) as exp(g + n*log(2)) */
208
- fx = _mm256_mul_ps (x, *reinterpret_cast <const __m256*>(_ps256_cephes_LOG2EF));
209
- fx = _mm256_add_ps (fx, *reinterpret_cast <const __m256*>(_ps256_0p5));
210
-
211
- tmp = _mm256_floor_ps (fx);
212
-
213
- /* if greater, substract 1 */
214
- __m256 mask = _mm256_cmp_ps (tmp, fx, _CMP_GT_OS);
215
- mask = _mm256_and_ps (mask, one);
216
- fx = _mm256_sub_ps (tmp, mask);
217
-
218
- tmp =
219
- _mm256_mul_ps (fx, *reinterpret_cast <const __m256*>(_ps256_cephes_exp_C1));
220
- __m256 z =
221
- _mm256_mul_ps (fx, *reinterpret_cast <const __m256*>(_ps256_cephes_exp_C2));
222
- x = _mm256_sub_ps (x, tmp);
223
- x = _mm256_sub_ps (x, z);
224
- z = _mm256_mul_ps (x, x);
225
- __m256 y = *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p0);
226
- y = _mm256_mul_ps (y, x);
227
- y = _mm256_add_ps (y, *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p1));
228
- y = _mm256_mul_ps (y, x);
229
- y = _mm256_add_ps (y, *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p2));
230
- y = _mm256_mul_ps (y, x);
231
- y = _mm256_add_ps (y, *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p3));
232
- y = _mm256_mul_ps (y, x);
233
- y = _mm256_add_ps (y, *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p4));
234
- y = _mm256_mul_ps (y, x);
235
- y = _mm256_add_ps (y, *reinterpret_cast <const __m256*>(_ps256_cephes_exp_p5));
236
- y = _mm256_mul_ps (y, z);
237
- y = _mm256_add_ps (y, x);
238
- y = _mm256_add_ps (y, one);
239
-
240
- /* build 2^n */
241
- imm0 = _mm256_cvttps_epi32 (fx);
202
+ AVXEXP_BASE;
242
203
// two AVX2 instructions
243
204
imm0 = _mm256_add_epi32 (imm0, *reinterpret_cast <const __m256i*>(_pi256_0x7f));
244
205
imm0 = _mm256_slli_epi32 (imm0, 23 );
0 commit comments