Skip to content

Commit 4dbdfa6

Browse files
committed
sigmoid and tanh support all size
test=develop
1 parent ccb8963 commit 4dbdfa6

File tree

2 files changed

+54
-63
lines changed

2 files changed

+54
-63
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 10 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -132,56 +132,8 @@ const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
132132
int g_tmp_mem[16] ALIGN32 = {0};
133133

134134
bool VActJitCode::init(int d, operand_type type) {
135-
bool ok = MayIUse(avx);
136-
if (type == operand_type::relu || type == operand_type::exp) {
137-
// TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
138-
return ok;
139-
} else {
140-
// TODO(TJ): support more
141-
return ok && d % 8 == 0;
142-
}
143-
}
144-
145-
void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
146-
int fy_idx, int mask_idx, int tmp_idx) {
147-
// y = 1 / (1 + e^-x)
148-
ymm_t ymm_tmp = ymm_t(tmp_idx);
149-
reg64_t reg_ptr_global = rax;
150-
push(reg_ptr_global);
151-
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
152-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
153-
vminps(ymm_src, ymm_src, ymm_tmp);
154-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
155-
vmaxps(ymm_src, ymm_src, ymm_tmp);
156-
vxorps(ymm_tmp, ymm_tmp, ymm_tmp);
157-
vsubps(ymm_src, ymm_tmp, ymm_src);
158-
exp_jmm<ymm_t>(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
159-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
160-
vaddps(ymm_dst, ymm_dst, ymm_tmp);
161-
vdivps(ymm_dst, ymm_tmp, ymm_dst);
162-
pop(reg_ptr_global);
163-
}
164-
165-
void VActJitCode::tanh_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
166-
int fy_idx, int mask_idx, int tmp_idx) {
167-
// y = 2 / (1 + e^(-2x)) - 1
168-
ymm_t ymm_tmp = ymm_t(tmp_idx);
169-
ymm_t ymm_zero = ymm_t(mask_idx);
170-
reg64_t reg_ptr_global = rax;
171-
push(reg_ptr_global);
172-
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
173-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
174-
vxorps(ymm_zero, ymm_zero, ymm_zero);
175-
vsubps(ymm_tmp, ymm_zero, ymm_tmp);
176-
vmulps(ymm_src, ymm_src, ymm_tmp);
177-
exp_jmm<ymm_t>(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
178-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
179-
vaddps(ymm_dst, ymm_dst, ymm_tmp);
180-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
181-
vdivps(ymm_dst, ymm_tmp, ymm_dst);
182-
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
183-
vsubps(ymm_dst, ymm_dst, ymm_tmp);
184-
pop(reg_ptr_global);
135+
// TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
136+
return MayIUse(avx);
185137
}
186138

187139
void VActJitCode::generate() {
@@ -201,10 +153,10 @@ void VActJitCode::generate() {
201153
exp_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
202154
break;
203155
case operand_type::sigmoid:
204-
sigmoid_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
156+
sigmoid_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
205157
break;
206158
case operand_type::tanh:
207-
tanh_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
159+
tanh_jmm<ymm_t>(ymm_dst, ymm_src, 2, 3, 4, 5);
208160
break;
209161
case operand_type::identity:
210162
break;
@@ -214,11 +166,6 @@ void VActJitCode::generate() {
214166
vmovups(ptr[param2 + offset], ymm_dst);
215167
offset += sizeof(float) * YMM_FLOAT_BLOCK;
216168
}
217-
if (type_ != operand_type::relu && type_ != operand_type::exp) {
218-
// TODO(TJ): remove me
219-
ret();
220-
return;
221-
}
222169
int rest = num_ % YMM_FLOAT_BLOCK;
223170
int block = XMM_FLOAT_BLOCK;
224171
while (rest > 0) {
@@ -236,6 +183,12 @@ void VActJitCode::generate() {
236183
case operand_type::exp:
237184
exp_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
238185
break;
186+
case operand_type::sigmoid:
187+
sigmoid_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
188+
break;
189+
case operand_type::tanh:
190+
tanh_jmm<xmm_t>(xmm_dst, xmm_src, 2, 3, 4, 5);
191+
break;
239192
default:
240193
break;
241194
}

paddle/fluid/operators/math/jit_code.h

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -263,13 +263,51 @@ class VActJitCode : public JitCode {
263263
pop(reg_ptr_global);
264264
}
265265

266-
// compute sigmoid with ymm
267-
void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
268-
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
266+
// compute sigmoid with ymm, xmm
267+
template <typename JMM>
268+
void sigmoid_jmm(JMM& dst, JMM& src, int fx_idx = 2, // NOLINT
269+
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5) {
270+
// y = 1 / (1 + e^-x)
271+
JMM jmm_tmp = JMM(tmp_idx);
272+
reg64_t reg_ptr_global = rax;
273+
push(reg_ptr_global);
274+
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
275+
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
276+
vminps(src, src, jmm_tmp);
277+
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]);
278+
vmaxps(src, src, jmm_tmp);
279+
vxorps(jmm_tmp, jmm_tmp, jmm_tmp);
280+
vsubps(src, jmm_tmp, src);
281+
exp_jmm<JMM>(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx);
282+
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
283+
vaddps(dst, dst, jmm_tmp);
284+
vdivps(dst, jmm_tmp, dst);
285+
pop(reg_ptr_global);
286+
}
269287

270-
// compute tanh with ymm
271-
void tanh_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
272-
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
288+
// compute tanh with ymm, xmm
289+
template <typename JMM>
290+
void tanh_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3, // NOLINT
291+
int mask_idx = 4, int tmp_idx = 5) {
292+
// y = 2 / (1 + e^(-2x)) - 1
293+
JMM jmm_tmp = JMM(tmp_idx);
294+
JMM jmm_zero = JMM(mask_idx);
295+
reg64_t reg_ptr_global = rax;
296+
push(reg_ptr_global);
297+
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
298+
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
299+
vxorps(jmm_zero, jmm_zero, jmm_zero);
300+
vsubps(jmm_tmp, jmm_zero, jmm_tmp);
301+
vmulps(src, src, jmm_tmp);
302+
exp_jmm<JMM>(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx);
303+
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
304+
vaddps(dst, dst, jmm_tmp);
305+
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
306+
vdivps(dst, jmm_tmp, dst);
307+
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
308+
vsubps(dst, dst, jmm_tmp);
309+
pop(reg_ptr_global);
310+
}
273311

274312
protected:
275313
int num_;

0 commit comments

Comments
 (0)