Skip to content

Commit 7484355

Browse files
committed
clean code exp avx
1 parent b4751a3 commit 7484355

File tree

1 file changed

+46
-85
lines changed

1 file changed

+46
-85
lines changed

paddle/fluid/operators/math/jit_kernel_exp.cc

Lines changed: 46 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -141,50 +141,52 @@ typedef union imm_xmm_union {
141141
AVX2_BITOP_USING_SSE2(slli_epi32);
142142
AVX2_INTOP_USING_SSE2(add_epi32);
143143

144+
/*
 * AVXEXP_BASE: statement sequence shared by ExpAVX and ExpAVX2.
 *
 * Requires a variable `x` (used as __m256) in the enclosing scope and
 * overwrites it. Declares the locals tmp, fx, one, imm0, mask, z and y in
 * that same scope, so it can be expanded only once per scope.
 *
 * Steps, in order:
 *   1. clamp x with _ps256_exp_hi (min) and _ps256_exp_lo (max);
 *   2. fx = floor(x * _ps256_cephes_LOG2EF + _ps256_0p5), minus one in any
 *      lane where the floored value compares greater than the unfloored one;
 *   3. x -= fx * _ps256_cephes_exp_C1; x -= fx * _ps256_cephes_exp_C2;
 *   4. y = polynomial in x with coefficients _ps256_cephes_exp_p0.._p5,
 *      multiplied by x*x, then + x + one;
 *   5. imm0 = fx converted to 32-bit integers with truncation.
 *
 * After expansion the caller uses y and imm0 to finish the computation.
 * The last statement deliberately has no trailing semicolon; invoke the
 * macro as `AVXEXP_BASE;`.
 */
#define AVXEXP_BASE \
145+
__m256 tmp = _mm256_setzero_ps(), fx; \
146+
__m256 one = *reinterpret_cast<const __m256*>(_ps256_one); \
147+
__m256i imm0; \
148+
x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi)); \
149+
x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo)); \
150+
/* express exp(x) as exp(g + n*log(2)) */ \
151+
fx = _mm256_mul_ps(x, \
152+
*reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF)); \
153+
fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5)); \
154+
tmp = _mm256_floor_ps(fx); \
155+
/* if the floored value is greater than fx, subtract 1 */ \
156+
__m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); \
157+
mask = _mm256_and_ps(mask, one); \
158+
fx = _mm256_sub_ps(tmp, mask); \
159+
tmp = _mm256_mul_ps(fx, \
160+
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1)); \
161+
__m256 z = _mm256_mul_ps( \
162+
fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2)); \
163+
x = _mm256_sub_ps(x, tmp); \
164+
x = _mm256_sub_ps(x, z); \
165+
z = _mm256_mul_ps(x, x); \
166+
__m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0); \
167+
y = _mm256_mul_ps(y, x); \
168+
y = _mm256_add_ps(y, \
169+
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1)); \
170+
y = _mm256_mul_ps(y, x); \
171+
y = _mm256_add_ps(y, \
172+
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2)); \
173+
y = _mm256_mul_ps(y, x); \
174+
y = _mm256_add_ps(y, \
175+
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3)); \
176+
y = _mm256_mul_ps(y, x); \
177+
y = _mm256_add_ps(y, \
178+
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4)); \
179+
y = _mm256_mul_ps(y, x); \
180+
y = _mm256_add_ps(y, \
181+
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5)); \
182+
y = _mm256_mul_ps(y, z); \
183+
y = _mm256_add_ps(y, x); \
184+
y = _mm256_add_ps(y, one); \
185+
/* build 2^n */ \
186+
imm0 = _mm256_cvttps_epi32(fx)
187+
144188
__m256 ExpAVX(__m256 x) {
145-
__m256 tmp = _mm256_setzero_ps(), fx;
146-
__m256 one = *reinterpret_cast<const __m256*>(_ps256_one);
147-
__m256i imm0;
148-
149-
x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi));
150-
x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo));
151-
152-
/* express exp(x) as exp(g + n*log(2)) */
153-
fx = _mm256_mul_ps(x, *reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF));
154-
fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5));
155-
156-
tmp = _mm256_floor_ps(fx);
157-
158-
/* if greater, subtract 1 */
159-
__m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
160-
mask = _mm256_and_ps(mask, one);
161-
fx = _mm256_sub_ps(tmp, mask);
162-
163-
tmp =
164-
_mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1));
165-
__m256 z =
166-
_mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2));
167-
x = _mm256_sub_ps(x, tmp);
168-
x = _mm256_sub_ps(x, z);
169-
z = _mm256_mul_ps(x, x);
170-
171-
__m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0);
172-
y = _mm256_mul_ps(y, x);
173-
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1));
174-
y = _mm256_mul_ps(y, x);
175-
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2));
176-
y = _mm256_mul_ps(y, x);
177-
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3));
178-
y = _mm256_mul_ps(y, x);
179-
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4));
180-
y = _mm256_mul_ps(y, x);
181-
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5));
182-
y = _mm256_mul_ps(y, z);
183-
y = _mm256_add_ps(y, x);
184-
y = _mm256_add_ps(y, one);
185-
186-
/* build 2^n */
187-
imm0 = _mm256_cvttps_epi32(fx);
189+
AVXEXP_BASE;
188190
// two AVX2 instructions using SSE2
189191
imm0 = avx2_mm256_add_epi32(imm0,
190192
*reinterpret_cast<const __m256i*>(_pi256_0x7f));
@@ -197,48 +199,7 @@ __m256 ExpAVX(__m256 x) {
197199

198200
#ifdef __AVX2__
199201
__m256 ExpAVX2(__m256 x) {
200-
__m256 tmp = _mm256_setzero_ps(), fx;
201-
__m256 one = *reinterpret_cast<const __m256*>(_ps256_one);
202-
__m256i imm0;
203-
204-
x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi));
205-
x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo));
206-
207-
/* express exp(x) as exp(g + n*log(2)) */
208-
fx = _mm256_mul_ps(x, *reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF));
209-
fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5));
210-
211-
tmp = _mm256_floor_ps(fx);
212-
213-
/* if greater, subtract 1 */
214-
__m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
215-
mask = _mm256_and_ps(mask, one);
216-
fx = _mm256_sub_ps(tmp, mask);
217-
218-
tmp =
219-
_mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1));
220-
__m256 z =
221-
_mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2));
222-
x = _mm256_sub_ps(x, tmp);
223-
x = _mm256_sub_ps(x, z);
224-
z = _mm256_mul_ps(x, x);
225-
__m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0);
226-
y = _mm256_mul_ps(y, x);
227-
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1));
228-
y = _mm256_mul_ps(y, x);
229-
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2));
230-
y = _mm256_mul_ps(y, x);
231-
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3));
232-
y = _mm256_mul_ps(y, x);
233-
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4));
234-
y = _mm256_mul_ps(y, x);
235-
y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5));
236-
y = _mm256_mul_ps(y, z);
237-
y = _mm256_add_ps(y, x);
238-
y = _mm256_add_ps(y, one);
239-
240-
/* build 2^n */
241-
imm0 = _mm256_cvttps_epi32(fx);
202+
AVXEXP_BASE;
242203
// two AVX2 instructions
243204
imm0 = _mm256_add_epi32(imm0, *reinterpret_cast<const __m256i*>(_pi256_0x7f));
244205
imm0 = _mm256_slli_epi32(imm0, 23);

0 commit comments

Comments
 (0)