Skip to content

Commit ee2a7f1

Browse files
committed
refine exp and fix error on avx
test=develop
1 parent 1e06a32 commit ee2a7f1

File tree

2 files changed

+15
-19
lines changed

2 files changed

+15
-19
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,8 @@ static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
197197
static int g_tmp_mem[16] ALIGN32 = {0};
198198

199199
void VExpJitCode::generate() {
200-
preCode();
201-
// push some?
202200
// in: ymm0, out: ymm1
203-
// use ymm 0~5 (and ymm 14~15 if avx only)
201+
// use ymm 0~5, rax
204202
int offset = 0;
205203
vmovups(ymm_src, ptr[param1 + offset]);
206204
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
@@ -222,7 +220,8 @@ void VExpJitCode::generate() {
222220
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
223221
vmulps(ymm_fy, ymm_fx, ymm_tmp);
224222
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
225-
vmulps(ymm_z, ymm_fx, ymm_tmp); // ymm_z use same with mask
223+
ymm_t ymm_z = ymm_t(ymm_mask.getIdx());
224+
vmulps(ymm_z, ymm_fx, ymm_tmp);
226225
vsubps(ymm_src, ymm_src, ymm_fy);
227226
vsubps(ymm_src, ymm_src, ymm_z);
228227
vmulps(ymm_z, ymm_src, ymm_src);
@@ -240,7 +239,6 @@ void VExpJitCode::generate() {
240239
vaddps(ymm_dst, ymm_dst, ymm_src);
241240
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
242241
vaddps(ymm_dst, ymm_dst, ymm_tmp);
243-
244242
// build 2^n
245243
ymm_t ymm_int = ymm_fx;
246244
vcvttps2dq(ymm_int, ymm_fx);
@@ -250,31 +248,30 @@ void VExpJitCode::generate() {
250248
vpaddd(ymm_int, ymm_int, ymm_tmp);
251249
vpslld(ymm_int, ymm_int, 23);
252250
} else if (MayIUse(avx)) {
253-
// use ymm_int, ymm_tmp and reg_ptr_global
254-
xmm_t xtmp1 = xmm_t(ymm_int); // or magic number should equal the ymm_int
255-
xmm_t xtmp2 = xmm_t(ymm_tmp); // or magic number should equal the ymm_tmp
256-
mov(reg_ptr_global, reinterpret_cast<size_t>(g_tmp_mem));
257-
vmovdqa(ptr[reg_ptr_global], ymm_int);
258-
vmovdqa(ptr[reg_ptr_global + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
251+
xmm_t xtmp1 = xmm_t(ymm_int.getIdx());
252+
xmm_t xtmp2 = xmm_t(ymm_tmp.getIdx());
253+
reg64_t reg_ptr_tmp = reg_ptr_global;
254+
mov(reg_ptr_tmp, reinterpret_cast<size_t>(g_tmp_mem));
255+
vmovdqa(ptr[reg_ptr_tmp], ymm_int);
256+
vmovdqa(ptr[reg_ptr_tmp + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
259257
vpaddd(xtmp1, xtmp1, xtmp2);
260258
vpslld(xtmp1, xtmp1, 23);
261-
vmovdqa(ptr[reg_ptr_global], xtmp1);
259+
vmovdqa(ptr[reg_ptr_tmp], xtmp1);
262260
// next 128bits
263-
vmovdqa(xtmp1, ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)]);
261+
vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]);
264262
vmovdqa(xtmp2,
265-
ptr[reg_ptr_global +
263+
ptr[reg_ptr_tmp +
266264
(AVX_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]);
267265
vpaddd(xtmp1, xtmp1, xtmp2);
268266
vpslld(xtmp1, xtmp1, 23);
269-
vmovdqa(ptr[reg_ptr_global + 4 /*xmm float block*/ * sizeof(float)], xtmp1);
267+
vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1);
270268
// load out
271-
vmovdqa(ymm_int, ptr[reg_ptr_global]);
269+
vmovdqa(ymm_int, ptr[reg_ptr_tmp]);
272270
}
273271
vmulps(ymm_dst, ymm_dst, ymm_int);
274272
vmovups(ptr[param2 + offset], ymm_dst);
275273

276-
// ret();
277-
postCode();
274+
ret();
278275
}
279276

280277
} // namespace gen

paddle/fluid/operators/math/jit_code.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ class VExpJitCode : public JitCode {
128128
ymm_t ymm_fx = ymm_t(2);
129129
ymm_t ymm_fy = ymm_t(3);
130130
ymm_t ymm_mask = ymm_t(4);
131-
ymm_t ymm_z = ymm_t(4);
132131
ymm_t ymm_tmp = ymm_t(5);
133132
};
134133

0 commit comments

Comments
 (0)