Skip to content

Commit ccb8963

Browse files
committed
refine exp jitcode with all size
test=develop
1 parent d3eae8f commit ccb8963

File tree

3 files changed

+153
-203
lines changed

3 files changed

+153
-203
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 24 additions & 199 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/math/jit_code.h"
16-
#include "paddle/fluid/operators/math/jit_kernel.h"
17-
#include "paddle/fluid/platform/cpu_info.h"
16+
#include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me
1817

1918
namespace paddle {
2019
namespace operators {
@@ -111,60 +110,26 @@ void VXXJitCode::generate() {
111110
ret();
112111
}
113112

114-
#define ALIGN32 __attribute__((aligned(32)))
115-
#define EXP_HIG 88.3762626647949f
116-
#define EXP_LOW -88.3762626647949f
117-
#define CEPHES_LOG2EF 1.44269504088896341
118-
#define CEPHES_EXP_C1 0.693359375
119-
#define CEPHES_EXP_C2 -2.12194440e-4
120-
#define CEPHES_EXP_P0 1.9875691500E-4
121-
#define CEPHES_EXP_P1 1.3981999507E-3
122-
#define CEPHES_EXP_P2 8.3334519073E-3
123-
#define CEPHES_EXP_P3 4.1665795894E-2
124-
#define CEPHES_EXP_P4 1.6666665459E-1
125-
#define CEPHES_EXP_P5 5.0000001201E-1
113+
const float exp_float_consts[] ALIGN32 = {REPEAT_8TIMES(1.f),
114+
REPEAT_8TIMES(2.f),
115+
REPEAT_8TIMES(0.5f),
116+
REPEAT_8TIMES(EXP_HIG),
117+
REPEAT_8TIMES(EXP_LOW),
118+
REPEAT_8TIMES(CEPHES_LOG2EF),
119+
REPEAT_8TIMES(CEPHES_EXP_C1),
120+
REPEAT_8TIMES(CEPHES_EXP_C2),
121+
REPEAT_8TIMES(CEPHES_EXP_P0),
122+
REPEAT_8TIMES(CEPHES_EXP_P1),
123+
REPEAT_8TIMES(CEPHES_EXP_P2),
124+
REPEAT_8TIMES(CEPHES_EXP_P3),
125+
REPEAT_8TIMES(CEPHES_EXP_P4),
126+
REPEAT_8TIMES(CEPHES_EXP_P5),
127+
REPEAT_8TIMES(EXP_MAX_INPUT),
128+
REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),
129+
REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)};
126130

127-
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
128-
129-
#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float)
130-
#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float)
131-
#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float)
132-
#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float)
133-
#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float)
134-
#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float)
135-
#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float)
136-
#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float)
137-
#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float)
138-
#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float)
139-
#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float)
140-
#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float)
141-
#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float)
142-
#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float)
143-
#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float)
144-
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
145-
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
146-
147-
static const float exp_float_consts[] ALIGN32 = {
148-
REPEAT_8TIMES(1.f),
149-
REPEAT_8TIMES(2.f),
150-
REPEAT_8TIMES(0.5f),
151-
REPEAT_8TIMES(EXP_HIG),
152-
REPEAT_8TIMES(EXP_LOW),
153-
REPEAT_8TIMES(CEPHES_LOG2EF),
154-
REPEAT_8TIMES(CEPHES_EXP_C1),
155-
REPEAT_8TIMES(CEPHES_EXP_C2),
156-
REPEAT_8TIMES(CEPHES_EXP_P0),
157-
REPEAT_8TIMES(CEPHES_EXP_P1),
158-
REPEAT_8TIMES(CEPHES_EXP_P2),
159-
REPEAT_8TIMES(CEPHES_EXP_P3),
160-
REPEAT_8TIMES(CEPHES_EXP_P4),
161-
REPEAT_8TIMES(CEPHES_EXP_P5),
162-
REPEAT_8TIMES(EXP_MAX_INPUT),
163-
REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),
164-
REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)};
165-
166-
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
167-
static int g_tmp_mem[16] ALIGN32 = {0};
131+
const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
132+
int g_tmp_mem[16] ALIGN32 = {0};
168133

169134
bool VActJitCode::init(int d, operand_type type) {
170135
bool ok = MayIUse(avx);
@@ -177,146 +142,6 @@ bool VActJitCode::init(int d, operand_type type) {
177142
}
178143
}
179144

180-
void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
181-
int fy_idx, int mask_idx, int tmp_idx) {
182-
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
183-
// check all idx can not equal
184-
ymm_t ymm_fx = ymm_t(fx_idx);
185-
ymm_t ymm_fy = ymm_t(fy_idx);
186-
ymm_t ymm_mask = ymm_t(mask_idx);
187-
ymm_t ymm_tmp = ymm_t(tmp_idx);
188-
reg64_t reg_ptr_global = rax;
189-
push(reg_ptr_global);
190-
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
191-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
192-
vminps(ymm_src, ymm_src, ymm_tmp);
193-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
194-
vmaxps(ymm_src, ymm_src, ymm_tmp);
195-
// express exp(x) as exp(g + n*log(2))
196-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
197-
vmulps(ymm_fx, ymm_src, ymm_tmp);
198-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
199-
vaddps(ymm_fx, ymm_fx, ymm_tmp);
200-
vroundps(ymm_fy, ymm_fx, 0x01);
201-
// if greater, substract 1
202-
vcmpgtps(ymm_mask, ymm_fy, ymm_fx);
203-
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
204-
vandps(ymm_mask, ymm_mask, ymm_tmp);
205-
vsubps(ymm_fx, ymm_fy, ymm_mask);
206-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
207-
vmulps(ymm_fy, ymm_fx, ymm_tmp);
208-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
209-
ymm_t ymm_z = ymm_t(ymm_mask.getIdx());
210-
vmulps(ymm_z, ymm_fx, ymm_tmp);
211-
vsubps(ymm_src, ymm_src, ymm_fy);
212-
vsubps(ymm_src, ymm_src, ymm_z);
213-
vmulps(ymm_z, ymm_src, ymm_src);
214-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
215-
vmulps(ymm_dst, ymm_src, ymm_tmp);
216-
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
217-
i += (YMM_FLOAT_BLOCK * sizeof(float))) {
218-
vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
219-
vaddps(ymm_dst, ymm_dst, ymm_tmp);
220-
vmulps(ymm_dst, ymm_dst, ymm_src);
221-
}
222-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
223-
vaddps(ymm_dst, ymm_dst, ymm_tmp);
224-
vmulps(ymm_dst, ymm_dst, ymm_z);
225-
vaddps(ymm_dst, ymm_dst, ymm_src);
226-
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
227-
vaddps(ymm_dst, ymm_dst, ymm_tmp);
228-
// build 2^n
229-
ymm_t ymm_int = ymm_fx;
230-
vcvttps2dq(ymm_int, ymm_fx);
231-
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
232-
vmovdqa(ymm_tmp, ptr[reg_ptr_global]);
233-
if (MayIUse(avx2)) {
234-
vpaddd(ymm_int, ymm_int, ymm_tmp);
235-
vpslld(ymm_int, ymm_int, 23);
236-
} else if (MayIUse(avx)) {
237-
xmm_t xtmp1 = xmm_t(ymm_int.getIdx());
238-
xmm_t xtmp2 = xmm_t(ymm_tmp.getIdx());
239-
reg64_t reg_ptr_tmp = reg_ptr_global;
240-
mov(reg_ptr_tmp, reinterpret_cast<size_t>(g_tmp_mem));
241-
vmovdqa(ptr[reg_ptr_tmp], ymm_int);
242-
vmovdqa(ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
243-
vpaddd(xtmp1, xtmp1, xtmp2);
244-
vpslld(xtmp1, xtmp1, 23);
245-
vmovdqa(ptr[reg_ptr_tmp], xtmp1);
246-
// next 128bits
247-
vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]);
248-
vmovdqa(xtmp2,
249-
ptr[reg_ptr_tmp +
250-
(YMM_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]);
251-
vpaddd(xtmp1, xtmp1, xtmp2);
252-
vpslld(xtmp1, xtmp1, 23);
253-
vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1);
254-
// load out
255-
vmovdqa(ymm_int, ptr[reg_ptr_tmp]);
256-
}
257-
vmulps(ymm_dst, ymm_dst, ymm_int);
258-
pop(reg_ptr_global);
259-
}
260-
261-
void VActJitCode::exp_xmm(xmm_t& ymm_dst, xmm_t& ymm_src, int fx_idx,
262-
int fy_idx, int mask_idx, int tmp_idx) {
263-
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
264-
// check all idx can not equal
265-
xmm_t ymm_fx = xmm_t(fx_idx);
266-
xmm_t ymm_fy = xmm_t(fy_idx);
267-
xmm_t ymm_mask = xmm_t(mask_idx);
268-
xmm_t ymm_tmp = xmm_t(tmp_idx);
269-
reg64_t reg_ptr_global = rax;
270-
push(reg_ptr_global);
271-
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
272-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
273-
vminps(ymm_src, ymm_src, ymm_tmp);
274-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
275-
vmaxps(ymm_src, ymm_src, ymm_tmp);
276-
// express exp(x) as exp(g + n*log(2))
277-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
278-
vmulps(ymm_fx, ymm_src, ymm_tmp);
279-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
280-
vaddps(ymm_fx, ymm_fx, ymm_tmp);
281-
vroundps(ymm_fy, ymm_fx, 0x01);
282-
// if greater, substract 1
283-
vcmpgtps(ymm_mask, ymm_fy, ymm_fx);
284-
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
285-
vandps(ymm_mask, ymm_mask, ymm_tmp);
286-
vsubps(ymm_fx, ymm_fy, ymm_mask);
287-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
288-
vmulps(ymm_fy, ymm_fx, ymm_tmp);
289-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
290-
xmm_t ymm_z = xmm_t(ymm_mask.getIdx());
291-
vmulps(ymm_z, ymm_fx, ymm_tmp);
292-
vsubps(ymm_src, ymm_src, ymm_fy);
293-
vsubps(ymm_src, ymm_src, ymm_z);
294-
vmulps(ymm_z, ymm_src, ymm_src);
295-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
296-
vmulps(ymm_dst, ymm_src, ymm_tmp);
297-
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
298-
i += (YMM_FLOAT_BLOCK * sizeof(float))) {
299-
vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
300-
vaddps(ymm_dst, ymm_dst, ymm_tmp);
301-
vmulps(ymm_dst, ymm_dst, ymm_src);
302-
}
303-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
304-
vaddps(ymm_dst, ymm_dst, ymm_tmp);
305-
vmulps(ymm_dst, ymm_dst, ymm_z);
306-
vaddps(ymm_dst, ymm_dst, ymm_src);
307-
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
308-
vaddps(ymm_dst, ymm_dst, ymm_tmp);
309-
// build 2^n
310-
xmm_t ymm_int = ymm_fx;
311-
vcvttps2dq(ymm_int, ymm_fx);
312-
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
313-
vmovdqa(ymm_tmp, ptr[reg_ptr_global]);
314-
vpaddd(ymm_int, ymm_int, ymm_tmp);
315-
vpslld(ymm_int, ymm_int, 23);
316-
vmulps(ymm_dst, ymm_dst, ymm_int);
317-
pop(reg_ptr_global);
318-
}
319-
320145
void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
321146
int fy_idx, int mask_idx, int tmp_idx) {
322147
// y = 1 / (1 + e^-x)
@@ -330,7 +155,7 @@ void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
330155
vmaxps(ymm_src, ymm_src, ymm_tmp);
331156
vxorps(ymm_tmp, ymm_tmp, ymm_tmp);
332157
vsubps(ymm_src, ymm_tmp, ymm_src);
333-
exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
158+
exp_jmm<ymm_t>(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
334159
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
335160
vaddps(ymm_dst, ymm_dst, ymm_tmp);
336161
vdivps(ymm_dst, ymm_tmp, ymm_dst);
@@ -349,7 +174,7 @@ void VActJitCode::tanh_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
349174
vxorps(ymm_zero, ymm_zero, ymm_zero);
350175
vsubps(ymm_tmp, ymm_zero, ymm_tmp);
351176
vmulps(ymm_src, ymm_src, ymm_tmp);
352-
exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
177+
exp_jmm<ymm_t>(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
353178
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
354179
vaddps(ymm_dst, ymm_dst, ymm_tmp);
355180
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
@@ -373,7 +198,7 @@ void VActJitCode::generate() {
373198
relu_jmm<ymm_t>(ymm_dst, ymm_src, ymm_zero);
374199
break;
375200
case operand_type::exp:
376-
exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
201+
exp_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
377202
break;
378203
case operand_type::sigmoid:
379204
sigmoid_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
@@ -409,7 +234,7 @@ void VActJitCode::generate() {
409234
relu_jmm<xmm_t>(xmm_dst, xmm_src, xmm_zero);
410235
break;
411236
case operand_type::exp:
412-
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
237+
exp_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
413238
break;
414239
default:
415240
break;

0 commit comments

Comments
 (0)