Skip to content

Commit 1e06a32

Browse files
committed
add vexp jitcode of size 8
test=develop
1 parent 2354409 commit 1e06a32

File tree

7 files changed

+241
-88
lines changed

7 files changed

+241
-88
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,132 @@ void ReluJitCode::generate() {
151151
}
152152
ret();
153153
}
154+
155+
bool VExpJitCode::init(int d) {
156+
return MayIUse(avx) && d == 8; // only 8 yet
157+
}
158+
159+
#define ALIGN32 __attribute__((aligned(32)))
160+
#define EXP_HIG 88.3762626647949f
161+
#define EXP_LOW -88.3762626647949f
162+
#define CEPHES_LOG2EF 1.44269504088896341
163+
#define CEPHES_EXP_C1 0.693359375
164+
#define CEPHES_EXP_C2 -2.12194440e-4
165+
#define CEPHES_EXP_P0 1.9875691500E-4
166+
#define CEPHES_EXP_P1 1.3981999507E-3
167+
#define CEPHES_EXP_P2 8.3334519073E-3
168+
#define CEPHES_EXP_P3 4.1665795894E-2
169+
#define CEPHES_EXP_P4 1.6666665459E-1
170+
#define CEPHES_EXP_P5 5.0000001201E-1
171+
172+
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
173+
174+
#define OFFSET_EXP_0P5 1 * AVX_FLOAT_BLOCK * sizeof(float)
175+
#define OFFSET_EXP_HIG 2 * AVX_FLOAT_BLOCK * sizeof(float)
176+
#define OFFSET_EXP_LOW 3 * AVX_FLOAT_BLOCK * sizeof(float)
177+
#define OFFSET_EXP_LOG2EF 4 * AVX_FLOAT_BLOCK * sizeof(float)
178+
#define OFFSET_EXP_C1 5 * AVX_FLOAT_BLOCK * sizeof(float)
179+
#define OFFSET_EXP_C2 6 * AVX_FLOAT_BLOCK * sizeof(float)
180+
#define OFFSET_EXP_P0 7 * AVX_FLOAT_BLOCK * sizeof(float)
181+
#define OFFSET_EXP_P1 8 * AVX_FLOAT_BLOCK * sizeof(float)
182+
#define OFFSET_EXP_P2 9 * AVX_FLOAT_BLOCK * sizeof(float)
183+
#define OFFSET_EXP_P3 10 * AVX_FLOAT_BLOCK * sizeof(float)
184+
#define OFFSET_EXP_P4 11 * AVX_FLOAT_BLOCK * sizeof(float)
185+
#define OFFSET_EXP_P5 12 * AVX_FLOAT_BLOCK * sizeof(float)
186+
187+
static const float exp_float_consts[] ALIGN32 = {
188+
REPEAT_8TIMES(1.f), REPEAT_8TIMES(0.5f),
189+
REPEAT_8TIMES(EXP_HIG), REPEAT_8TIMES(EXP_LOW),
190+
REPEAT_8TIMES(CEPHES_LOG2EF), REPEAT_8TIMES(CEPHES_EXP_C1),
191+
REPEAT_8TIMES(CEPHES_EXP_C2), REPEAT_8TIMES(CEPHES_EXP_P0),
192+
REPEAT_8TIMES(CEPHES_EXP_P1), REPEAT_8TIMES(CEPHES_EXP_P2),
193+
REPEAT_8TIMES(CEPHES_EXP_P3), REPEAT_8TIMES(CEPHES_EXP_P4),
194+
REPEAT_8TIMES(CEPHES_EXP_P5)};
195+
196+
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
197+
static int g_tmp_mem[16] ALIGN32 = {0};
198+
199+
void VExpJitCode::generate() {
200+
preCode();
201+
// push some?
202+
// in: ymm0, out: ymm1
203+
// use ymm 0~5 (and ymm 14~15 if avx only)
204+
int offset = 0;
205+
vmovups(ymm_src, ptr[param1 + offset]);
206+
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
207+
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
208+
vminps(ymm_src, ymm_src, ymm_tmp);
209+
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
210+
vmaxps(ymm_src, ymm_src, ymm_tmp);
211+
// express exp(x) as exp(g + n*log(2))
212+
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
213+
vmulps(ymm_fx, ymm_src, ymm_tmp);
214+
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
215+
vaddps(ymm_fx, ymm_fx, ymm_tmp);
216+
vroundps(ymm_fy, ymm_fx, 0x01);
217+
// if greater, substract 1
218+
vcmpgtps(ymm_mask, ymm_fy, ymm_fx);
219+
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
220+
vandps(ymm_mask, ymm_mask, ymm_tmp);
221+
vsubps(ymm_fx, ymm_fy, ymm_mask);
222+
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
223+
vmulps(ymm_fy, ymm_fx, ymm_tmp);
224+
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
225+
vmulps(ymm_z, ymm_fx, ymm_tmp); // ymm_z use same with mask
226+
vsubps(ymm_src, ymm_src, ymm_fy);
227+
vsubps(ymm_src, ymm_src, ymm_z);
228+
vmulps(ymm_z, ymm_src, ymm_src);
229+
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
230+
vmulps(ymm_dst, ymm_src, ymm_tmp);
231+
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
232+
i += (AVX_FLOAT_BLOCK * sizeof(float))) {
233+
vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
234+
vaddps(ymm_dst, ymm_dst, ymm_tmp);
235+
vmulps(ymm_dst, ymm_dst, ymm_src);
236+
}
237+
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
238+
vaddps(ymm_dst, ymm_dst, ymm_tmp);
239+
vmulps(ymm_dst, ymm_dst, ymm_z);
240+
vaddps(ymm_dst, ymm_dst, ymm_src);
241+
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
242+
vaddps(ymm_dst, ymm_dst, ymm_tmp);
243+
244+
// build 2^n
245+
ymm_t ymm_int = ymm_fx;
246+
vcvttps2dq(ymm_int, ymm_fx);
247+
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
248+
vmovdqa(ymm_tmp, ptr[reg_ptr_global]);
249+
if (MayIUse(avx2)) {
250+
vpaddd(ymm_int, ymm_int, ymm_tmp);
251+
vpslld(ymm_int, ymm_int, 23);
252+
} else if (MayIUse(avx)) {
253+
// use ymm_int, ymm_tmp and reg_ptr_global
254+
xmm_t xtmp1 = xmm_t(ymm_int); // or magic number should equal the ymm_int
255+
xmm_t xtmp2 = xmm_t(ymm_tmp); // or magic number should equal the ymm_tmp
256+
mov(reg_ptr_global, reinterpret_cast<size_t>(g_tmp_mem));
257+
vmovdqa(ptr[reg_ptr_global], ymm_int);
258+
vmovdqa(ptr[reg_ptr_global + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
259+
vpaddd(xtmp1, xtmp1, xtmp2);
260+
vpslld(xtmp1, xtmp1, 23);
261+
vmovdqa(ptr[reg_ptr_global], xtmp1);
262+
// next 128bits
263+
vmovdqa(xtmp1, ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)]);
264+
vmovdqa(xtmp2,
265+
ptr[reg_ptr_global +
266+
(AVX_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]);
267+
vpaddd(xtmp1, xtmp1, xtmp2);
268+
vpslld(xtmp1, xtmp1, 23);
269+
vmovdqa(ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)], xtmp1);
270+
// load out
271+
vmovdqa(ymm_int, ptr[reg_ptr_global]);
272+
}
273+
vmulps(ymm_dst, ymm_dst, ymm_int);
274+
vmovups(ptr[param2 + offset], ymm_dst);
275+
276+
// ret();
277+
postCode();
278+
}
279+
154280
} // namespace gen
155281
} // namespace jitkernel
156282
} // namespace math

paddle/fluid/operators/math/jit_code.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,30 @@ class ReluJitCode : public JitCode {
108108
ymm_t ymm_dst = ymm_t(1);
109109
};
110110

111+
class VExpJitCode : public JitCode {
112+
public:
113+
DECLARE_JIT_CODE(VExpJitCode);
114+
explicit VExpJitCode(int d, size_t code_size = 256 * 1024,
115+
void* code_ptr = nullptr)
116+
: JitCode(code_size, code_ptr), num_(d) {}
117+
static bool init(int d);
118+
void generate() override;
119+
120+
private:
121+
int num_;
122+
reg64_t param1{abi_param1};
123+
reg64_t param2{abi_param2};
124+
125+
reg64_t reg_ptr_global = rax;
126+
ymm_t ymm_src = ymm_t(0);
127+
ymm_t ymm_dst = ymm_t(1);
128+
ymm_t ymm_fx = ymm_t(2);
129+
ymm_t ymm_fy = ymm_t(3);
130+
ymm_t ymm_mask = ymm_t(4);
131+
ymm_t ymm_z = ymm_t(4);
132+
ymm_t ymm_tmp = ymm_t(5);
133+
};
134+
111135
} // namespace gen
112136
} // namespace jitkernel
113137
} // namespace math

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ template <typename T>
117117
class VExpKernel : public VActKernel<T> {
118118
public:
119119
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
120+
void (*Compute)(const T *, T *, int);
120121
};
121122

122123
template <typename T>

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,6 @@ limitations under the License. */
2525
#include "paddle/fluid/platform/dynload/mklml.h"
2626
#endif
2727

28-
#ifdef __AVX__
29-
#include <immintrin.h>
30-
#endif
31-
3228
namespace paddle {
3329
namespace operators {
3430
namespace math {
@@ -128,18 +124,11 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
128124

129125
#endif
130126

131-
#define DECLARE_STATIC_FUNC \
132-
static inline std::string name(int d) { \
133-
PADDLE_THROW("DType should be either float or double"); \
134-
} \
135-
static inline bool useJIT(int d) { return false; } \
136-
static inline bool useMKL(int d) { return false; }
137-
138127
/* VMUL JitKernel */
139128
template <typename T>
140129
class VMulKernelImpl : public VMulKernel<T> {
141130
public:
142-
DECLARE_STATIC_FUNC;
131+
JITKERNEL_DECLARE_STATIC_FUNC;
143132
explicit VMulKernelImpl(int d) : VMulKernel<T>() {
144133
#ifdef PADDLE_WITH_XBYAK
145134
if (useJIT(d)) {
@@ -191,7 +180,7 @@ bool VMulKernelImpl<double>::useMKL(int d) {
191180
template <typename T>
192181
class VAddKernelImpl : public VAddKernel<T> {
193182
public:
194-
DECLARE_STATIC_FUNC;
183+
JITKERNEL_DECLARE_STATIC_FUNC;
195184
explicit VAddKernelImpl(int d) : VAddKernel<T>() {
196185
#ifdef PADDLE_WITH_XBYAK
197186
if (useJIT(d)) {
@@ -241,7 +230,7 @@ bool VAddKernelImpl<double>::useMKL(int d) {
241230
template <typename T>
242231
class VAddReluKernelImpl : public VAddReluKernel<T> {
243232
public:
244-
DECLARE_STATIC_FUNC;
233+
JITKERNEL_DECLARE_STATIC_FUNC;
245234
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() {
246235
#ifdef PADDLE_WITH_XBYAK
247236
if (useJIT(d)) {
@@ -273,7 +262,7 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
273262
template <typename T>
274263
class VScalKernelImpl : public VScalKernel<T> {
275264
public:
276-
DECLARE_STATIC_FUNC;
265+
JITKERNEL_DECLARE_STATIC_FUNC;
277266
explicit VScalKernelImpl(int d) : VScalKernel<T>() {
278267
#ifdef PADDLE_WITH_XBYAK
279268
if (useJIT(d)) {
@@ -322,7 +311,7 @@ bool VScalKernelImpl<double>::useMKL(int d) {
322311
template <typename T>
323312
class VAddBiasKernelImpl : public VAddBiasKernel<T> {
324313
public:
325-
DECLARE_STATIC_FUNC;
314+
JITKERNEL_DECLARE_STATIC_FUNC;
326315
explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() {
327316
#ifdef PADDLE_WITH_XBYAK
328317
if (useJIT(d)) {
@@ -355,14 +344,14 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
355344
template <typename T>
356345
class VReluKernelImpl : public VReluKernel<T> {
357346
public:
358-
DECLARE_STATIC_FUNC;
347+
JITKERNEL_DECLARE_STATIC_FUNC;
359348
explicit VReluKernelImpl(int d) : VReluKernel<T>() {
360349
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
361350
#ifdef PADDLE_WITH_XBYAK
362351
if (useJIT(d)) {
363-
size_t sz = 96 /*init*/ +
364-
d / AVX_FLOAT_BLOCK * 4 /* instructions*/ *
365-
8 /*everage byte for each instruction*/;
352+
size_t sz = 96 /* init size */ +
353+
d / AVX_FLOAT_BLOCK * 4 /* instructions */ *
354+
8 /* average bytes for each instruction */;
366355
jitcode_.reset(new gen::ReluJitCode(d, sz > 4096 ? sz : 4096));
367356
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
368357
return;
@@ -388,8 +377,6 @@ bool VReluKernelImpl<float>::useJIT(int d) {
388377
}
389378
#endif
390379

391-
#undef DECLARE_STATIC_FUNC
392-
393380
REGISTER_JITKERNEL(vmul, VMulKernel);
394381
REGISTER_JITKERNEL(vadd, VAddKernel);
395382
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);

0 commit comments

Comments
 (0)