Skip to content

Commit 046374b

Browse files
committed
add vsigmoid jitcode of size 8
1 parent ee2a7f1 commit 046374b

File tree

5 files changed

+177
-161
lines changed

5 files changed

+177
-161
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,6 @@ void ReluJitCode::generate() {
152152
ret();
153153
}
154154

155-
bool VExpJitCode::init(int d) {
156-
return MayIUse(avx) && d == 8; // only 8 yet
157-
}
158-
159155
#define ALIGN32 __attribute__((aligned(32)))
160156
#define EXP_HIG 88.3762626647949f
161157
#define EXP_LOW -88.3762626647949f
@@ -171,6 +167,7 @@ bool VExpJitCode::init(int d) {
171167

172168
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
173169

170+
#define OFFSET_EXP_ONE 0 * AVX_FLOAT_BLOCK * sizeof(float)
174171
#define OFFSET_EXP_0P5 1 * AVX_FLOAT_BLOCK * sizeof(float)
175172
#define OFFSET_EXP_HIG 2 * AVX_FLOAT_BLOCK * sizeof(float)
176173
#define OFFSET_EXP_LOW 3 * AVX_FLOAT_BLOCK * sizeof(float)
@@ -183,24 +180,43 @@ bool VExpJitCode::init(int d) {
183180
#define OFFSET_EXP_P3 10 * AVX_FLOAT_BLOCK * sizeof(float)
184181
#define OFFSET_EXP_P4 11 * AVX_FLOAT_BLOCK * sizeof(float)
185182
#define OFFSET_EXP_P5 12 * AVX_FLOAT_BLOCK * sizeof(float)
183+
#define OFFSET_EXP_MAX_INPUT 13 * AVX_FLOAT_BLOCK * sizeof(float)
184+
#define OFFSET_SIGMOID_MAX 14 * AVX_FLOAT_BLOCK * sizeof(float)
185+
#define OFFSET_SIGMOID_MIN 15 * AVX_FLOAT_BLOCK * sizeof(float)
186186

187187
// Table of broadcast float constants read by the JIT-generated code.
// Each REPEAT_8TIMES(x) entry expands to 8 copies of x, so every constant
// occupies one group of 8 floats. The OFFSET_* macros address a group as
// `index * AVX_FLOAT_BLOCK * sizeof(float)`, so the order of the entries here
// must stay in sync with those indices (e.g. 1.f is group 0 / OFFSET_EXP_ONE,
// CEPHES_EXP_P5 is group 12 / OFFSET_EXP_P5). The sigmoid support appends
// EXP_MAX_INPUT, SIGMOID_THRESHOLD_MAX and SIGMOID_THRESHOLD_MIN as groups
// 13, 14 and 15 (OFFSET_EXP_MAX_INPUT, OFFSET_SIGMOID_MAX, OFFSET_SIGMOID_MIN).
// The table is declared ALIGN32 and is loaded with vmovaps by the generators.
static const float exp_float_consts[] ALIGN32 = {
188-
REPEAT_8TIMES(1.f), REPEAT_8TIMES(0.5f),
189-
REPEAT_8TIMES(EXP_HIG), REPEAT_8TIMES(EXP_LOW),
190-
REPEAT_8TIMES(CEPHES_LOG2EF), REPEAT_8TIMES(CEPHES_EXP_C1),
191-
REPEAT_8TIMES(CEPHES_EXP_C2), REPEAT_8TIMES(CEPHES_EXP_P0),
192-
REPEAT_8TIMES(CEPHES_EXP_P1), REPEAT_8TIMES(CEPHES_EXP_P2),
193-
REPEAT_8TIMES(CEPHES_EXP_P3), REPEAT_8TIMES(CEPHES_EXP_P4),
194-
REPEAT_8TIMES(CEPHES_EXP_P5)};
188+
REPEAT_8TIMES(1.f),
189+
REPEAT_8TIMES(0.5f),
190+
REPEAT_8TIMES(EXP_HIG),
191+
REPEAT_8TIMES(EXP_LOW),
192+
REPEAT_8TIMES(CEPHES_LOG2EF),
193+
REPEAT_8TIMES(CEPHES_EXP_C1),
194+
REPEAT_8TIMES(CEPHES_EXP_C2),
195+
REPEAT_8TIMES(CEPHES_EXP_P0),
196+
REPEAT_8TIMES(CEPHES_EXP_P1),
197+
REPEAT_8TIMES(CEPHES_EXP_P2),
198+
REPEAT_8TIMES(CEPHES_EXP_P3),
199+
REPEAT_8TIMES(CEPHES_EXP_P4),
200+
REPEAT_8TIMES(CEPHES_EXP_P5),
201+
REPEAT_8TIMES(EXP_MAX_INPUT),
202+
REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),
203+
REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)};
195204

196205
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
197206
static int g_tmp_mem[16] ALIGN32 = {0};
198207

199-
void VExpJitCode::generate() {
200-
// in: ymm0, out: ymm1
201-
// use ymm 0~5, rax
202-
int offset = 0;
203-
vmovups(ymm_src, ptr[param1 + offset]);
208+
// Reports whether a JIT exp kernel can be generated for vectors of length d.
// Returns true only when MayIUse(avx) holds and d is exactly 8.
bool VExpJitCode::init(int d) {
209+
return MayIUse(avx) && d == 8; // only d == 8 is supported for now
210+
}
211+
212+
void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
213+
// use reg rax and ymm 2~5
214+
reg64_t reg_ptr_global = rax;
215+
ymm_t ymm_fx = ymm_t(2);
216+
ymm_t ymm_fy = ymm_t(3);
217+
ymm_t ymm_mask = ymm_t(4);
218+
ymm_t ymm_tmp = ymm_t(5);
219+
push(reg_ptr_global);
204220
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
205221
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
206222
vminps(ymm_src, ymm_src, ymm_tmp);
@@ -269,8 +285,45 @@ void VExpJitCode::generate() {
269285
vmovdqa(ymm_int, ptr[reg_ptr_tmp]);
270286
}
271287
vmulps(ymm_dst, ymm_dst, ymm_int);
288+
pop(reg_ptr_global);
289+
}
290+
291+
// Emits the complete exp kernel: load one ymm register from the address in
// param1, run exp_ymm on it, store the result to the address in param2, and
// return. Only a single ymm-wide block at offset 0 is processed, which matches
// init() accepting d == 8 only.
void VExpJitCode::generate() {
292+
int offset = 0;
293+
// Unaligned load of the input block into ymm_src.
vmovups(ymm_src, ptr[param1 + offset]);
294+
// ymm_dst = exp(ymm_src); exp_ymm clamps ymm_src in place, so the source
// register does not keep its original value.
exp_ymm(ymm_src, ymm_dst);
272295
// Unaligned store of the result block.
vmovups(ptr[param2 + offset], ymm_dst);
296+
ret();
297+
}
298+
299+
// Reports whether a JIT sigmoid kernel can be generated for vectors of
// length d. Returns true only when MayIUse(avx) holds and d is exactly 8.
bool VSigmoidJitCode::init(int d) {
300+
return MayIUse(avx) && d == 8; // only d == 8 is supported for now
301+
}
273302

303+
// Emits instructions computing ymm_dst = 1 / (1 + exp(-x)), where x is
// ymm_src clamped to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX].
//   ymm_src: input register; it is overwritten (clamped, negated, and then
//            further modified by exp_ymm).
//   ymm_dst: receives the result.
// rax is saved on entry and restored on exit.
void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
304+
// uses rax and ymm2 directly; the nested exp_ymm call uses rax and ymm 2~5
305+
reg64_t reg_ptr_global = rax;
306+
ymm_t ymm_tmp = ymm_t(2);
307+
// Preserve rax, then point it at the shared constant table.
push(reg_ptr_global);
308+
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
309+
// Clamp from above: src = min(src, SIGMOID_THRESHOLD_MAX).
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
310+
vminps(ymm_src, ymm_src, ymm_tmp);
311+
// Clamp from below: src = max(src, SIGMOID_THRESHOLD_MIN).
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
312+
vmaxps(ymm_src, ymm_src, ymm_tmp);
313+
// Negate: tmp = 0, then src = 0 - src.
vxorps(ymm_tmp, ymm_tmp, ymm_tmp);
314+
vsubps(ymm_src, ymm_tmp, ymm_src);
315+
// dst = exp(-x). exp_ymm pushes/pops rax itself, so rax still points at
// exp_float_consts afterwards, but it overwrites ymm2 (our ymm_tmp).
exp_ymm(ymm_src, ymm_dst);
316+
// Reload 1.0f (ymm_tmp was clobbered by exp_ymm), then dst = 1 / (1 + dst).
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
317+
vaddps(ymm_dst, ymm_dst, ymm_tmp);
318+
vdivps(ymm_dst, ymm_tmp, ymm_dst);
319+
// Restore the caller's rax.
pop(reg_ptr_global);
320+
}
321+
322+
// Emits the complete sigmoid kernel: load one ymm register from the address
// in param1, run sigmoid_ymm on it, store the result to the address in param2,
// and return. Only a single ymm-wide block at offset 0 is processed, which
// matches init() accepting d == 8 only.
void VSigmoidJitCode::generate() {
323+
int offset = 0;
324+
// Unaligned load of the input block into ymm_src.
vmovups(ymm_src, ptr[param1 + offset]);
325+
// ymm_dst = sigmoid(ymm_src); ymm_src is overwritten in the process.
sigmoid_ymm(ymm_src, ymm_dst);
326+
// Unaligned store of the result block.
vmovups(ptr[param2 + offset], ymm_dst);
274327
ret();
275328
}
276329

paddle/fluid/operators/math/jit_code.h

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,36 @@ class VExpJitCode : public JitCode {
117117
static bool init(int d);
118118
void generate() override;
119119

120+
protected:
121+
// compute exp with ymm
122+
void exp_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst);
123+
120124
private:
121125
int num_;
122126
reg64_t param1{abi_param1};
123127
reg64_t param2{abi_param2};
128+
ymm_t ymm_src = ymm_t(0);
129+
ymm_t ymm_dst = ymm_t(1);
130+
};
124131

125-
reg64_t reg_ptr_global = rax;
132+
// JIT code generator for a vectorized sigmoid. It derives from VExpJitCode so
// that sigmoid_ymm can call the inherited exp_ymm helper.
// NOTE(review): num_, param1, param2, ymm_src and ymm_dst are declared again
// here; the same names are private in VExpJitCode, so each object carries two
// copies — confirm whether the base members should become protected instead.
class VSigmoidJitCode : public VExpJitCode {
133+
public:
134+
DECLARE_JIT_CODE(VSigmoidJitCode);
135+
// d is the vector length; it is forwarded to VExpJitCode together with
// code_size and code_ptr, and also stored in this class's own num_.
explicit VSigmoidJitCode(int d, size_t code_size = 256 * 1024,
136+
void* code_ptr = nullptr)
137+
: VExpJitCode(d, code_size, code_ptr), num_(d) {}
138+
// Returns whether a sigmoid kernel can be generated for length d.
static bool init(int d);
139+
// Emits the load / sigmoid / store / ret sequence.
void generate() override;
140+
141+
// compute sigmoid with ymm: emits dst = 1 / (1 + exp(-src)); src is clobbered
142+
void sigmoid_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst);
143+
144+
private:
145+
int num_;
146+
// Input and output pointer registers (first and second ABI parameters).
reg64_t param1{abi_param1};
147+
reg64_t param2{abi_param2};
126148
// ymm0 holds the input block, ymm1 the result.
ymm_t ymm_src = ymm_t(0);
127149
ymm_t ymm_dst = ymm_t(1);
128-
ymm_t ymm_fx = ymm_t(2);
129-
ymm_t ymm_fy = ymm_t(3);
130-
ymm_t ymm_mask = ymm_t(4);
131-
ymm_t ymm_tmp = ymm_t(5);
132150
};
133151

134152
} // namespace gen

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ namespace jitkernel {
2929
#define SIGMOID_THRESHOLD_MIN -40.0
3030
#define SIGMOID_THRESHOLD_MAX 13.0
3131
#define EXP_MAX_INPUT 40.0
32+
// TODO(TJ): change AVX_FLOAT_BLOCK to YMM_FLOAT_BLOCK
3233
#define AVX_FLOAT_BLOCK 8
3334
#define AVX2_FLOAT_BLOCK 8
3435
#define AVX512_FLOAT_BLOCK 16
@@ -124,6 +125,7 @@ template <typename T>
124125
// Abstract interface for a vectorized sigmoid kernel over element type T.
class VSigmoidKernel : public VActKernel<T> {
125126
public:
126127
// Older entry point; must be overridden by concrete kernels.
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
128+
// Function-pointer entry point taking (input, output, int).
// NOTE(review): the pointer has no initializer in this declaration —
// presumably the concrete kernel assigns it (e.g. to JIT-generated code) and
// the int is the element count; confirm before calling through it.
void (*Compute)(const T *, T *, int);
127129
};
128130

129131
template <typename T>

0 commit comments

Comments
 (0)