Skip to content

Commit ba3eaed

Browse files
committed
exp support all size
1 parent d239801 commit ba3eaed

File tree

3 files changed

+113
-14
lines changed

3 files changed

+113
-14
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 103 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,10 @@ void VXXJitCode::generate() {
8181
}
8282
if (rest >= 2) {
8383
if (scalar_index_ != 1) {
84-
vmovups(xmm_src1, ptr[param1 + offset]);
84+
vmovq(xmm_src1, ptr[param1 + offset]);
8585
}
8686
if (scalar_index_ != 2) {
87-
vmovups(xmm_src2, ptr[param2 + offset]);
87+
vmovq(xmm_src2, ptr[param2 + offset]);
8888
}
8989
if (type_ == operand_type::mul) {
9090
vmulps(xmm_dst, xmm_src1, xmm_src2);
@@ -100,10 +100,10 @@ void VXXJitCode::generate() {
100100
}
101101
if (rest > 0) {
102102
if (scalar_index_ != 1) {
103-
vmovups(xmm_src1, ptr[param1 + offset]);
103+
vmovss(xmm_src1, ptr[param1 + offset]);
104104
}
105105
if (scalar_index_ != 2) {
106-
vmovups(xmm_src2, ptr[param2 + offset]);
106+
vmovss(xmm_src2, ptr[param2 + offset]);
107107
}
108108
if (type_ == operand_type::mul) {
109109
vmulss(xmm_dst, xmm_src1, xmm_src2);
@@ -179,7 +179,7 @@ bool VActJitCode::init(int d, operand_type type) {
179179
return ok;
180180
} else if (type == operand_type::exp) {
181181
// exp is slower than mkl when d >= 256
182-
return ok && d % 8 == 0 && d < 256;
182+
return ok; //&& d % 4 == 0 && d < 256;
183183
} else {
184184
// TODO(TJ): support more
185185
return ok && d % 8 == 0;
@@ -190,6 +190,10 @@ void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
190190
vmaxps(ymm_dst, ymm_zero, ymm_src);
191191
}
192192

193+
// Emits the xmm-register form of ReLU: a single packed-single max of
// xmm_zero and xmm_src, written to xmm_dst. Mirrors relu_ymm but operates on
// xmm_t operands.
// NOTE(review): xmm_zero presumably holds all zeros, set up by the caller -
// confirm in generate().
void VActJitCode::relu_xmm(xmm_t& xmm_dst, xmm_t& xmm_src, xmm_t& xmm_zero) {
194+
vmaxps(xmm_dst, xmm_zero, xmm_src);
195+
}
196+
193197
void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
194198
int fy_idx, int mask_idx, int tmp_idx) {
195199
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
@@ -271,6 +275,65 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
271275
pop(reg_ptr_global);
272276
}
273277

278+
// Emits the instruction sequence that computes exp() of the packed floats in
// ymm_src and leaves the result in ymm_dst, using xmm-width registers.
//
// Despite the "ymm_" prefixes (kept from exp_ymm), every register handled
// here is an xmm_t.
//
// Parameters:
//   ymm_dst  - destination register; must not be the same register as ymm_src.
//   ymm_src  - input register; it is overwritten while the code runs
//              (clamped, then reduced), so its value is not preserved.
//   fx_idx, fy_idx, mask_idx, tmp_idx - indices of four scratch xmm registers
//              that are clobbered by the emitted code.
//
// rax is saved with push at entry and restored with pop at exit because it is
// used to hold the address of the constant tables.
// NOTE(review): the table loads use vmovaps/vmovdqa, so exp_float_consts and
// exp_int_0x7f are assumed to be suitably aligned - confirm at their
// definitions.
void VActJitCode::exp_xmm(xmm_t& ymm_dst, xmm_t& ymm_src, int fx_idx,
279+
int fy_idx, int mask_idx, int tmp_idx) {
280+
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enforce
281+
// all register indices should be distinct; only src != dst is checked above
282+
xmm_t ymm_fx = xmm_t(fx_idx);
283+
xmm_t ymm_fy = xmm_t(fy_idx);
284+
xmm_t ymm_mask = xmm_t(mask_idx);
285+
xmm_t ymm_tmp = xmm_t(tmp_idx);
286+
reg64_t reg_ptr_global = rax;
287+
// preserve rax, then point it at the float constant table
push(reg_ptr_global);
288+
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
289+
// clamp the input: src = max(min(src, EXP_HIG), EXP_LOW)
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
290+
vminps(ymm_src, ymm_src, ymm_tmp);
291+
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
292+
vmaxps(ymm_src, ymm_src, ymm_tmp);
293+
// express exp(x) as exp(g + n*log(2))
294+
// fx = src * LOG2EF + 0.5
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
295+
vmulps(ymm_fx, ymm_src, ymm_tmp);
296+
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
297+
vaddps(ymm_fx, ymm_fx, ymm_tmp);
298+
// fy = fx rounded with mode 0x01 (round toward negative infinity)
vroundps(ymm_fy, ymm_fx, 0x01);
299+
// if greater, subtract 1
300+
// mask = (fy > fx) & table[0]; table[0] is presumably 1.0f - confirm
vcmpgtps(ymm_mask, ymm_fy, ymm_fx);
301+
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
302+
vandps(ymm_mask, ymm_mask, ymm_tmp);
303+
vsubps(ymm_fx, ymm_fy, ymm_mask);
304+
// fy = fx * C1, z = fx * C2; both are then subtracted from src
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
305+
vmulps(ymm_fy, ymm_fx, ymm_tmp);
306+
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
307+
// z reuses the mask register, whose previous contents are no longer needed
xmm_t ymm_z = xmm_t(ymm_mask.getIdx());
308+
vmulps(ymm_z, ymm_fx, ymm_tmp);
309+
vsubps(ymm_src, ymm_src, ymm_fy);
310+
vsubps(ymm_src, ymm_src, ymm_z);
311+
// z = src * src
vmulps(ymm_z, ymm_src, ymm_src);
312+
// polynomial in src: dst = src * P0, then dst = (dst + Pi) * src for P1~P4
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
313+
vmulps(ymm_dst, ymm_src, ymm_tmp);
314+
// the table is stepped by YMM_FLOAT_BLOCK floats per entry (same stride as
// in exp_ymm), although only an xmm-width vector is loaded from each entry
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
315+
i += (YMM_FLOAT_BLOCK * sizeof(float))) {
316+
vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
317+
vaddps(ymm_dst, ymm_dst, ymm_tmp);
318+
vmulps(ymm_dst, ymm_dst, ymm_src);
319+
}
320+
// dst = (dst + P5) * z + src + table[0]
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
321+
vaddps(ymm_dst, ymm_dst, ymm_tmp);
322+
vmulps(ymm_dst, ymm_dst, ymm_z);
323+
vaddps(ymm_dst, ymm_dst, ymm_src);
324+
vmovaps(ymm_tmp, ptr[reg_ptr_global]);
325+
vaddps(ymm_dst, ymm_dst, ymm_tmp);
326+
// build 2^n
327+
// ymm_int aliases the fx register: fx is truncated to int32 in place, the
// exp_int_0x7f constant is added, and the sum is shifted left by 23 bits
xmm_t ymm_int = ymm_fx;
328+
vcvttps2dq(ymm_int, ymm_fx);
329+
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
330+
vmovdqa(ymm_tmp, ptr[reg_ptr_global]);
331+
vpaddd(ymm_int, ymm_int, ymm_tmp);
332+
vpslld(ymm_int, ymm_int, 23);
333+
// scale the polynomial result by the bit pattern built above
vmulps(ymm_dst, ymm_dst, ymm_int);
334+
// restore the caller's rax
pop(reg_ptr_global);
335+
}
336+
274337
void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
275338
int fy_idx, int mask_idx, int tmp_idx) {
276339
// y = 1 / (1 + e^-x)
@@ -343,29 +406,58 @@ void VActJitCode::generate() {
343406
vmovups(ptr[param2 + offset], ymm_dst);
344407
offset += sizeof(float) * YMM_FLOAT_BLOCK;
345408
}
346-
if (type_ != operand_type::relu) {
409+
if (type_ != operand_type::relu && type_ != operand_type::exp) {
347410
// TODO(TJ): remove me
348411
ret();
349412
return;
350413
}
351414
int rest = num_ % YMM_FLOAT_BLOCK;
352415
if (rest >= 4) {
353416
vmovups(xmm_src, ptr[param1 + offset]);
354-
vmaxps(xmm_dst, xmm_zero, xmm_src);
417+
switch (type_) {
418+
case operand_type::relu:
419+
relu_xmm(xmm_dst, xmm_src, xmm_zero);
420+
break;
421+
case operand_type::exp:
422+
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
423+
break;
424+
default:
425+
break;
426+
}
355427
vmovups(ptr[param2 + offset], xmm_dst);
356428
offset += sizeof(float) * 4;
357429
rest -= 4;
358430
}
359431
if (rest >= 2) {
360-
vmovups(xmm_src, ptr[param1 + offset]);
361-
vmaxps(xmm_dst, xmm_zero, xmm_src);
432+
vmovq(xmm_src, ptr[param1 + offset]);
433+
switch (type_) {
434+
case operand_type::relu:
435+
relu_xmm(xmm_dst, xmm_src, xmm_zero);
436+
break;
437+
case operand_type::exp:
438+
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
439+
break;
440+
default:
441+
break;
442+
}
362443
vmovq(ptr[param2 + offset], xmm_dst);
363444
offset += sizeof(float) * 2;
364445
rest -= 2;
365446
}
366447
if (rest > 0) {
367-
vmovups(xmm_src, ptr[param1 + offset]);
368-
vmaxps(xmm_dst, xmm_zero, xmm_src);
448+
// vmovups();
449+
vmovss(xmm_src, ptr[param1 + offset]);
450+
451+
switch (type_) {
452+
case operand_type::relu:
453+
relu_xmm(xmm_dst, xmm_src, xmm_zero);
454+
break;
455+
case operand_type::exp:
456+
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
457+
break;
458+
default:
459+
break;
460+
}
369461
vmovss(ptr[param2 + offset], xmm_dst);
370462
}
371463
ret();

paddle/fluid/operators/math/jit_code.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,17 @@ class VActJitCode : public JitCode {
127127
void generate() override;
128128

129129
protected:
130-
// compute relu with ymm
130+
// compute relu with ymm, xmm
131131
void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src,
132132
const Xbyak::Ymm& zero);
133+
void relu_xmm(const Xbyak::Xmm& dst, const Xbyak::Xmm& src,
134+
const Xbyak::Xmm& zero);
133135

134-
// compute exp with ymm
136+
// compute exp with ymm, xmm
135137
void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
136138
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
139+
void exp_xmm(const Xbyak::Xmm& dst, const Xbyak::Xmm& src, int fx_idx = 2,
140+
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
137141

138142
// compute sigmoid with ymm
139143
void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,

paddle/fluid/operators/math/jit_kernel_test.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ limitations under the License. */
3333

3434
constexpr int repeat = 20000;
3535

36+
// TODO(TJ): benchmark and test should be separated,
37+
// benchmark should verify more sizes
38+
3639
inline double GetCurrentUS() {
3740
struct timeval time;
3841
gettimeofday(&time, NULL);
@@ -156,7 +159,7 @@ void vexp_mkl(const int n, const float* x, float* y) {
156159

157160
TEST(JitKernel, vexp) {
158161
namespace jit = paddle::operators::math::jitkernel;
159-
for (int d : {7, 8, 15, 16, 30, 128, 256}) {
162+
for (int d : {7, 8, 12, 15, 16, 20, 30, 128, 256}) {
160163
std::vector<float> x(d);
161164
std::vector<float> zref(d), ztgt(d);
162165
RandomVec<float>(d, x.data(), -2.f, 2.f);

0 commit comments

Comments
 (0)