Commit 7f17e56

Merge pull request #14423 from tensor-tang/fea/jit/act

jitcode act relu, exp, sigmoid, tanh

2 parents 28bd5b7 + 1f00723

10 files changed: +626 −503 lines

paddle/fluid/operators/math/cpu_vec.h

Lines changed: 9 additions & 9 deletions
@@ -33,11 +33,11 @@ namespace math {
 #define SIGMOID_THRESHOLD_MIN -40.0
 #define SIGMOID_THRESHOLD_MAX 13.0
 
-#define AVX_FLOAT_BLOCK 8
+#define YMM_FLOAT_BLOCK 8
 #define AVX_DOUBLE_BLOCK 4
-#define AVX2_FLOAT_BLOCK 8
+#define YMM_FLOAT_BLOCK 8
 #define AVX2_DOUBLE_BLOCK 4
-#define AVX512_FLOAT_BLOCK 16
+#define ZMM_FLOAT_BLOCK 16
 #define AVX512_DOUBLE_BLOCK 8
 
 template <typename T>
@@ -88,7 +88,7 @@ template <>
 inline void vec_scal<float, platform::jit::avx>(const int n, const float a,
                                                 const float* x, float* y) {
 #ifdef __AVX__
-  constexpr int block = AVX_FLOAT_BLOCK;
+  constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
     vec_scal<float, platform::jit::isa_any>(n, a, x, y);
     return;
@@ -142,7 +142,7 @@ template <>
 inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a,
                                                     const float* x, float* y) {
 #ifdef __AVX__
-  constexpr int block = AVX_FLOAT_BLOCK;
+  constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
     vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y);
     return;
@@ -200,7 +200,7 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
                                                  const float* y, const float* z,
                                                  float* out) {
 #ifdef __AVX__
-  constexpr int block = AVX_FLOAT_BLOCK;
+  constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
     vec_cross<float, platform::jit::isa_any>(n, x, y, z, out);
     return;
@@ -257,7 +257,7 @@ template <>
 inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a,
                                                     const float* x, float* y) {
 #ifdef __AVX__
-  constexpr int block = AVX_FLOAT_BLOCK;
+  constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
     vec_add_bias<float, platform::jit::isa_any>(n, a, x, y);
     return;
@@ -326,7 +326,7 @@ template <>
 inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x,
                                                    float* y) {
 #ifdef __AVX__
-  constexpr int block = AVX_FLOAT_BLOCK;
+  constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
     vec_sigmoid<float, platform::jit::isa_any>(n, x, y);
     return;
@@ -415,7 +415,7 @@ template <>
 inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
                                                 float* y) {
 #ifdef __AVX__
-  constexpr int block = AVX_FLOAT_BLOCK;
+  constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block * 4) {
     vec_relu<float, platform::jit::isa_any>(n, x, y);
     return;
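The renames above track register width rather than ISA generation: a 256-bit YMM register holds 8 floats and a 512-bit ZMM register holds 16, so AVX_FLOAT_BLOCK and AVX2_FLOAT_BLOCK (both 8) collapse into YMM_FLOAT_BLOCK while AVX512_FLOAT_BLOCK becomes ZMM_FLOAT_BLOCK. For readers unfamiliar with these kernels, the sketch below shows in plain scalar C++ the block/remainder dispatch shape the vec_* functions share (vec_relu itself uses the larger n < block * 4 cutoff). It is an illustration only, not code from this commit, and the helper name is hypothetical.

constexpr int kYmmFloatBlock = 8;  // one 256-bit YMM register = 8 floats

// Hypothetical scalar illustration of the dispatch used by the vec_* kernels:
// small sizes fall back to the generic (isa_any) path, full blocks would be
// handled one AVX instruction per YMM_FLOAT_BLOCK elements, and the tail is
// finished element by element.
void vec_relu_sketch(const int n, const float* x, float* y) {
  constexpr int block = kYmmFloatBlock;
  if (n < block) {
    for (int i = 0; i < n; ++i) y[i] = x[i] > 0.f ? x[i] : 0.f;  // fallback path
    return;
  }
  const int end = n - n % block;
  for (int i = 0; i < end; i += block) {
    // in the real kernel this inner loop is a single vmaxps over a YMM register
    for (int j = 0; j < block; ++j) y[i + j] = x[i + j] > 0.f ? x[i + j] : 0.f;
  }
  for (int i = end; i < n; ++i) y[i] = x[i] > 0.f ? x[i] : 0.f;  // remainder
}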

paddle/fluid/operators/math/jit_code.cc

Lines changed: 230 additions & 10 deletions
@@ -41,7 +41,7 @@ void VXXJitCode::generate() {
   } else if (scalar_index_ == 2) {
     vbroadcastss(ymm_src2, ptr[param2]);
   }
-  for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
+  for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
     if (scalar_index_ != 1) {
       vmovups(ymm_src1, ptr[param1 + offset]);
     }
@@ -57,9 +57,9 @@ void VXXJitCode::generate() {
       vmaxps(ymm_dst, ymm_zero, ymm_dst);
     }
     vmovups(ptr[param3 + offset], ymm_dst);
-    offset += sizeof(float) * AVX_FLOAT_BLOCK;
+    offset += sizeof(float) * YMM_FLOAT_BLOCK;
   }
-  int rest = num_ % AVX_FLOAT_BLOCK;
+  int rest = num_ % YMM_FLOAT_BLOCK;
   if (rest >= 4) {
     if (scalar_index_ != 1) {
       vmovups(xmm_src1, ptr[param1 + offset]);
@@ -118,18 +118,237 @@ void VXXJitCode::generate() {
   ret();
 }
 
-bool ReluJitCode::init(int d) { return MayIUse(avx); }
+#define ALIGN32 __attribute__((aligned(32)))
+#define EXP_HIG 88.3762626647949f
+#define EXP_LOW -88.3762626647949f
+#define CEPHES_LOG2EF 1.44269504088896341
+#define CEPHES_EXP_C1 0.693359375
+#define CEPHES_EXP_C2 -2.12194440e-4
+#define CEPHES_EXP_P0 1.9875691500E-4
+#define CEPHES_EXP_P1 1.3981999507E-3
+#define CEPHES_EXP_P2 8.3334519073E-3
+#define CEPHES_EXP_P3 4.1665795894E-2
+#define CEPHES_EXP_P4 1.6666665459E-1
+#define CEPHES_EXP_P5 5.0000001201E-1
 
-void ReluJitCode::generate() {
-  int offset = 0;
+#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
+
+#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
+#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
+
+static const float exp_float_consts[] ALIGN32 = {
+    REPEAT_8TIMES(1.f),
+    REPEAT_8TIMES(2.f),
+    REPEAT_8TIMES(0.5f),
+    REPEAT_8TIMES(EXP_HIG),
+    REPEAT_8TIMES(EXP_LOW),
+    REPEAT_8TIMES(CEPHES_LOG2EF),
+    REPEAT_8TIMES(CEPHES_EXP_C1),
+    REPEAT_8TIMES(CEPHES_EXP_C2),
+    REPEAT_8TIMES(CEPHES_EXP_P0),
+    REPEAT_8TIMES(CEPHES_EXP_P1),
+    REPEAT_8TIMES(CEPHES_EXP_P2),
+    REPEAT_8TIMES(CEPHES_EXP_P3),
+    REPEAT_8TIMES(CEPHES_EXP_P4),
+    REPEAT_8TIMES(CEPHES_EXP_P5),
+    REPEAT_8TIMES(EXP_MAX_INPUT),
+    REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),
+    REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)};
+
+static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
+static int g_tmp_mem[16] ALIGN32 = {0};
+
+bool VActJitCode::init(int d, operand_type type) {
+  bool ok = MayIUse(avx);
+  if (type == operand_type::relu) {
+    return ok;
+  } else if (type == operand_type::exp) {
+    // exp is slower than mkl when d >= 256
+    return ok && d % 8 == 0 && d < 256;
+  } else {
+    // TODO(TJ): support more
+    return ok && d % 8 == 0;
+  }
+}
+
+void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
+  vmaxps(ymm_dst, ymm_zero, ymm_src);
+}
+
+void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
+                          int fy_idx, int mask_idx, int tmp_idx) {
+  assert(ymm_src.getIdx() != ymm_dst.getIdx());  // TODO(TJ): use enfore
+  // check all idx can not equal
+  ymm_t ymm_fx = ymm_t(fx_idx);
+  ymm_t ymm_fy = ymm_t(fy_idx);
+  ymm_t ymm_mask = ymm_t(mask_idx);
+  ymm_t ymm_tmp = ymm_t(tmp_idx);
+  reg64_t reg_ptr_global = rax;
+  push(reg_ptr_global);
+  mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
+  vminps(ymm_src, ymm_src, ymm_tmp);
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
+  vmaxps(ymm_src, ymm_src, ymm_tmp);
+  // express exp(x) as exp(g + n*log(2))
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
+  vmulps(ymm_fx, ymm_src, ymm_tmp);
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
+  vaddps(ymm_fx, ymm_fx, ymm_tmp);
+  vroundps(ymm_fy, ymm_fx, 0x01);
+  // if greater, substract 1
+  vcmpgtps(ymm_mask, ymm_fy, ymm_fx);
+  vmovaps(ymm_tmp, ptr[reg_ptr_global]);
+  vandps(ymm_mask, ymm_mask, ymm_tmp);
+  vsubps(ymm_fx, ymm_fy, ymm_mask);
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
+  vmulps(ymm_fy, ymm_fx, ymm_tmp);
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
+  ymm_t ymm_z = ymm_t(ymm_mask.getIdx());
+  vmulps(ymm_z, ymm_fx, ymm_tmp);
+  vsubps(ymm_src, ymm_src, ymm_fy);
+  vsubps(ymm_src, ymm_src, ymm_z);
+  vmulps(ymm_z, ymm_src, ymm_src);
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
+  vmulps(ymm_dst, ymm_src, ymm_tmp);
+  for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
+       i += (YMM_FLOAT_BLOCK * sizeof(float))) {
+    vmovaps(ymm_tmp, ptr[reg_ptr_global + i]);  // P1~P4
+    vaddps(ymm_dst, ymm_dst, ymm_tmp);
+    vmulps(ymm_dst, ymm_dst, ymm_src);
+  }
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
+  vaddps(ymm_dst, ymm_dst, ymm_tmp);
+  vmulps(ymm_dst, ymm_dst, ymm_z);
+  vaddps(ymm_dst, ymm_dst, ymm_src);
+  vmovaps(ymm_tmp, ptr[reg_ptr_global]);
+  vaddps(ymm_dst, ymm_dst, ymm_tmp);
+  // build 2^n
+  ymm_t ymm_int = ymm_fx;
+  vcvttps2dq(ymm_int, ymm_fx);
+  mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
+  vmovdqa(ymm_tmp, ptr[reg_ptr_global]);
+  if (MayIUse(avx2)) {
+    vpaddd(ymm_int, ymm_int, ymm_tmp);
+    vpslld(ymm_int, ymm_int, 23);
+  } else if (MayIUse(avx)) {
+    xmm_t xtmp1 = xmm_t(ymm_int.getIdx());
+    xmm_t xtmp2 = xmm_t(ymm_tmp.getIdx());
+    reg64_t reg_ptr_tmp = reg_ptr_global;
+    mov(reg_ptr_tmp, reinterpret_cast<size_t>(g_tmp_mem));
+    vmovdqa(ptr[reg_ptr_tmp], ymm_int);
+    vmovdqa(ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
+    vpaddd(xtmp1, xtmp1, xtmp2);
+    vpslld(xtmp1, xtmp1, 23);
+    vmovdqa(ptr[reg_ptr_tmp], xtmp1);
+    // next 128bits
+    vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]);
+    vmovdqa(xtmp2,
+            ptr[reg_ptr_tmp +
+                (YMM_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]);
+    vpaddd(xtmp1, xtmp1, xtmp2);
+    vpslld(xtmp1, xtmp1, 23);
+    vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1);
+    // load out
+    vmovdqa(ymm_int, ptr[reg_ptr_tmp]);
+  }
+  vmulps(ymm_dst, ymm_dst, ymm_int);
+  pop(reg_ptr_global);
+}
+
+void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
+                              int fy_idx, int mask_idx, int tmp_idx) {
+  // y = 1 / (1 + e^-x)
+  ymm_t ymm_tmp = ymm_t(tmp_idx);
+  reg64_t reg_ptr_global = rax;
+  push(reg_ptr_global);
+  mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
+  vminps(ymm_src, ymm_src, ymm_tmp);
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
+  vmaxps(ymm_src, ymm_src, ymm_tmp);
+  vxorps(ymm_tmp, ymm_tmp, ymm_tmp);
+  vsubps(ymm_src, ymm_tmp, ymm_src);
+  exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
+  vaddps(ymm_dst, ymm_dst, ymm_tmp);
+  vdivps(ymm_dst, ymm_tmp, ymm_dst);
+  pop(reg_ptr_global);
+}
+
+void VActJitCode::tanh_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
+                           int fy_idx, int mask_idx, int tmp_idx) {
+  // y = 2 / (1 + e^(-2x)) - 1
+  ymm_t ymm_tmp = ymm_t(tmp_idx);
+  ymm_t ymm_zero = ymm_t(mask_idx);
+  reg64_t reg_ptr_global = rax;
+  push(reg_ptr_global);
+  mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
   vxorps(ymm_zero, ymm_zero, ymm_zero);
-  for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
+  vsubps(ymm_tmp, ymm_zero, ymm_tmp);
+  vmulps(ymm_src, ymm_src, ymm_tmp);
+  exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
+  vaddps(ymm_dst, ymm_dst, ymm_tmp);
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
+  vdivps(ymm_dst, ymm_tmp, ymm_dst);
+  vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
+  vsubps(ymm_dst, ymm_dst, ymm_tmp);
+  pop(reg_ptr_global);
+}
+
+void VActJitCode::generate() {
+  xmm_t xmm_zero = xmm_t(2);
+  ymm_t ymm_zero = ymm_t(2);
+  if (type_ == operand_type::relu) {
+    vxorps(ymm_zero, ymm_zero, ymm_zero);
+  }
+  int offset = 0;
+  for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
     vmovups(ymm_src, ptr[param1 + offset]);
-    vmaxps(ymm_dst, ymm_zero, ymm_src);
+    switch (type_) {
+      case operand_type::relu:
+        relu_ymm(ymm_dst, ymm_src, ymm_zero);
+        break;
+      case operand_type::exp:
+        exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
+        break;
+      case operand_type::sigmoid:
+        sigmoid_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
+        break;
+      case operand_type::tanh:
+        tanh_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
+        break;
+      case operand_type::identity:
+        break;
+      default:
+        break;
+    }
     vmovups(ptr[param2 + offset], ymm_dst);
-    offset += sizeof(float) * AVX_FLOAT_BLOCK;
+    offset += sizeof(float) * YMM_FLOAT_BLOCK;
+  }
+  if (type_ != operand_type::relu) {
+    // TODO(TJ): remove me
+    ret();
+    return;
   }
-  int rest = num_ % AVX_FLOAT_BLOCK;
+  int rest = num_ % YMM_FLOAT_BLOCK;
   if (rest >= 4) {
     vmovups(xmm_src, ptr[param1 + offset]);
     vmaxps(xmm_dst, xmm_zero, xmm_src);
@@ -151,6 +370,7 @@ void ReluJitCode::generate() {
   }
   ret();
 }
+
 } // namespace gen
 } // namespace jitkernel
 } // namespace math
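For reference, the new exp_ymm follows the standard Cephes-style recipe its comments hint at: clamp x to [EXP_LOW, EXP_HIG], write exp(x) as 2^n * exp(g) with n = round(x * log2(e)) and g = x - n*C1 - n*C2 (C1 and C2 split ln 2 for precision), approximate exp(g) with the P0..P5 polynomial, and build 2^n by adding the exponent bias 0x7f to n and shifting it into the float exponent field (the vpslld by 23). sigmoid_ymm and tanh_ymm then reuse it as y = 1 / (1 + e^-x) and y = 2 / (1 + e^(-2x)) - 1, exactly as their comments state. The scalar C++ below is an illustration of that sequence using the exp_float_consts values; it is not code from the commit and the *_sketch names are made up.

#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical scalar mirror of exp_ymm.
static float exp_sketch(float x) {
  x = std::fmin(x, 88.3762626647949f);   // EXP_HIG
  x = std::fmax(x, -88.3762626647949f);  // EXP_LOW
  // exp(x) = 2^n * exp(g), n = round(x * log2(e))
  float fx = x * 1.44269504088896341f + 0.5f;  // CEPHES_LOG2EF, then + 0.5
  float n = std::floor(fx);
  if (n > fx) n -= 1.f;  // mirrors the vcmpgtps/vandps/vsubps fixup (no-op after floor)
  x -= n * 0.693359375f;     // CEPHES_EXP_C1
  x -= n * -2.12194440e-4f;  // CEPHES_EXP_C2
  float z = x * x;
  float y = 1.9875691500E-4f;    // P0
  y = y * x + 1.3981999507E-3f;  // P1
  y = y * x + 8.3334519073E-3f;  // P2
  y = y * x + 4.1665795894E-2f;  // P3
  y = y * x + 1.6666665459E-1f;  // P4
  y = y * x + 5.0000001201E-1f;  // P5
  y = y * z + x + 1.f;
  // build 2^n: place (n + 0x7f) in the float exponent bits (shift by 23)
  int32_t bits = (static_cast<int32_t>(n) + 0x7f) << 23;
  float pow2n;
  std::memcpy(&pow2n, &bits, sizeof(pow2n));
  return y * pow2n;
}

// y = 1 / (1 + e^-x), input clamped to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX]
static float sigmoid_sketch(float x) {
  x = std::fmin(std::fmax(x, -40.f), 13.f);
  return 1.f / (1.f + exp_sketch(-x));
}

// y = 2 / (1 + e^(-2x)) - 1
static float tanh_sketch(float x) {
  return 2.f / (1.f + exp_sketch(-2.f * x)) - 1.f;
}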
