Skip to content

Commit 6a15907

Browse files
committed
add vtanh jitcode of size 8
1 parent 046374b commit 6a15907

File tree

5 files changed

+153
-166
lines changed

5 files changed

+153
-166
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -168,24 +168,26 @@ void ReluJitCode::generate() {
168168
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
169169

170170
#define OFFSET_EXP_ONE 0 * AVX_FLOAT_BLOCK * sizeof(float)
171-
#define OFFSET_EXP_0P5 1 * AVX_FLOAT_BLOCK * sizeof(float)
172-
#define OFFSET_EXP_HIG 2 * AVX_FLOAT_BLOCK * sizeof(float)
173-
#define OFFSET_EXP_LOW 3 * AVX_FLOAT_BLOCK * sizeof(float)
174-
#define OFFSET_EXP_LOG2EF 4 * AVX_FLOAT_BLOCK * sizeof(float)
175-
#define OFFSET_EXP_C1 5 * AVX_FLOAT_BLOCK * sizeof(float)
176-
#define OFFSET_EXP_C2 6 * AVX_FLOAT_BLOCK * sizeof(float)
177-
#define OFFSET_EXP_P0 7 * AVX_FLOAT_BLOCK * sizeof(float)
178-
#define OFFSET_EXP_P1 8 * AVX_FLOAT_BLOCK * sizeof(float)
179-
#define OFFSET_EXP_P2 9 * AVX_FLOAT_BLOCK * sizeof(float)
180-
#define OFFSET_EXP_P3 10 * AVX_FLOAT_BLOCK * sizeof(float)
181-
#define OFFSET_EXP_P4 11 * AVX_FLOAT_BLOCK * sizeof(float)
182-
#define OFFSET_EXP_P5 12 * AVX_FLOAT_BLOCK * sizeof(float)
183-
#define OFFSET_EXP_MAX_INPUT 13 * AVX_FLOAT_BLOCK * sizeof(float)
184-
#define OFFSET_SIGMOID_MAX 14 * AVX_FLOAT_BLOCK * sizeof(float)
185-
#define OFFSET_SIGMOID_MIN 15 * AVX_FLOAT_BLOCK * sizeof(float)
171+
#define OFFSET_EXP_TWO 1 * AVX_FLOAT_BLOCK * sizeof(float)
172+
#define OFFSET_EXP_0P5 2 * AVX_FLOAT_BLOCK * sizeof(float)
173+
#define OFFSET_EXP_HIG 3 * AVX_FLOAT_BLOCK * sizeof(float)
174+
#define OFFSET_EXP_LOW 4 * AVX_FLOAT_BLOCK * sizeof(float)
175+
#define OFFSET_EXP_LOG2EF 5 * AVX_FLOAT_BLOCK * sizeof(float)
176+
#define OFFSET_EXP_C1 6 * AVX_FLOAT_BLOCK * sizeof(float)
177+
#define OFFSET_EXP_C2 7 * AVX_FLOAT_BLOCK * sizeof(float)
178+
#define OFFSET_EXP_P0 8 * AVX_FLOAT_BLOCK * sizeof(float)
179+
#define OFFSET_EXP_P1 9 * AVX_FLOAT_BLOCK * sizeof(float)
180+
#define OFFSET_EXP_P2 10 * AVX_FLOAT_BLOCK * sizeof(float)
181+
#define OFFSET_EXP_P3 11 * AVX_FLOAT_BLOCK * sizeof(float)
182+
#define OFFSET_EXP_P4 12 * AVX_FLOAT_BLOCK * sizeof(float)
183+
#define OFFSET_EXP_P5 13 * AVX_FLOAT_BLOCK * sizeof(float)
184+
#define OFFSET_EXP_MAX_INPUT 14 * AVX_FLOAT_BLOCK * sizeof(float)
185+
#define OFFSET_SIGMOID_MAX 15 * AVX_FLOAT_BLOCK * sizeof(float)
186+
#define OFFSET_SIGMOID_MIN 16 * AVX_FLOAT_BLOCK * sizeof(float)
186187

187188
static const float exp_float_consts[] ALIGN32 = {
188189
REPEAT_8TIMES(1.f),
190+
REPEAT_8TIMES(2.f),
189191
REPEAT_8TIMES(0.5f),
190192
REPEAT_8TIMES(EXP_HIG),
191193
REPEAT_8TIMES(EXP_LOW),
@@ -216,6 +218,7 @@ void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
216218
ymm_t ymm_fy = ymm_t(3);
217219
ymm_t ymm_mask = ymm_t(4);
218220
ymm_t ymm_tmp = ymm_t(5);
221+
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enforce
219222
push(reg_ptr_global);
220223
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
221224
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
@@ -327,6 +330,40 @@ void VSigmoidJitCode::generate() {
327330
ret();
328331
}
329332

333+
bool VTanhJitCode::init(int d) {
334+
return MayIUse(avx) && d == 8; // only 8 yet
335+
}
336+
337+
void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
338+
// y = 2 / (1 + e^(-2x)) - 1
339+
// use ymm2, ymm3
340+
reg64_t reg_ptr_global = rax;
341+
ymm_t ymm_tmp = ymm_t(2);
342+
ymm_t ymm_zero = ymm_t(3);
343+
push(reg_ptr_global);
344+
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
345+
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
346+
vxorps(ymm_zero, ymm_zero, ymm_zero);
347+
vsubps(ymm_tmp, ymm_zero, ymm_tmp);
348+
vmulps(ymm_src, ymm_src, ymm_tmp);
349+
exp_ymm(ymm_src, ymm_dst);
350+
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
351+
vaddps(ymm_dst, ymm_dst, ymm_tmp);
352+
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
353+
vdivps(ymm_dst, ymm_tmp, ymm_dst);
354+
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
355+
vsubps(ymm_dst, ymm_dst, ymm_tmp);
356+
pop(reg_ptr_global);
357+
}
358+
359+
void VTanhJitCode::generate() {
360+
int offset = 0;
361+
vmovups(ymm_src, ptr[param1 + offset]);
362+
vtanh_ymm(ymm_src, ymm_dst);
363+
vmovups(ptr[param2 + offset], ymm_dst);
364+
ret();
365+
}
366+
330367
} // namespace gen
331368
} // namespace jitkernel
332369
} // namespace math

paddle/fluid/operators/math/jit_code.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,26 @@ class VSigmoidJitCode : public VExpJitCode {
149149
ymm_t ymm_dst = ymm_t(1);
150150
};
151151

152+
class VTanhJitCode : public VExpJitCode {
153+
public:
154+
DECLARE_JIT_CODE(VTanhJitCode);
155+
explicit VTanhJitCode(int d, size_t code_size = 256 * 1024,
156+
void* code_ptr = nullptr)
157+
: VExpJitCode(d, code_size, code_ptr), num_(d) {}
158+
static bool init(int d);
159+
void generate() override;
160+
161+
// compute tanh with ymm
162+
void vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst);
163+
164+
private:
165+
int num_;
166+
reg64_t param1{abi_param1};
167+
reg64_t param2{abi_param2};
168+
ymm_t ymm_src = ymm_t(0);
169+
ymm_t ymm_dst = ymm_t(1);
170+
};
171+
152172
} // namespace gen
153173
} // namespace jitkernel
154174
} // namespace math

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ template <typename T>
132132
class VTanhKernel : public VActKernel<T> {
133133
public:
134134
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
135+
void (*Compute)(const T *, T *, int);
135136
};
136137

137138
template <typename T>

paddle/fluid/operators/math/jit_kernel_exp.cc

Lines changed: 79 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ void VExpRefer(const T* x, T* y, int n) {
4545

4646
template <typename T>
4747
void VSigmoidRefer(const T* x, T* y, int n) {
48+
// y = 1 / (1 + e^-x)
4849
const T min = SIGMOID_THRESHOLD_MIN;
4950
const T max = SIGMOID_THRESHOLD_MAX;
5051
for (int i = 0; i < n; ++i) {
@@ -53,6 +54,18 @@ void VSigmoidRefer(const T* x, T* y, int n) {
5354
}
5455
}
5556

57+
template <typename T>
58+
void VTanhRefer(const T* x, T* y, int n) {
59+
// y = 2 * sigmoid(2x) - 1
60+
for (int i = 0; i < n; ++i) {
61+
y[i] = static_cast<T>(2) * x[i];
62+
}
63+
VSigmoidRefer(y, y, n);
64+
for (int i = 0; i < n; ++i) {
65+
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
66+
}
67+
}
68+
5669
#ifdef PADDLE_WITH_MKLML
5770
template <typename T>
5871
void VExpMKL(const T* x, T* y, int n);
@@ -80,6 +93,17 @@ void VSigmoidMKL(const T* x, T* y, int n) {
8093
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
8194
}
8295
}
96+
97+
template <typename T>
98+
void VTanhMKL(const T* x, T* y, int n) {
99+
for (int i = 0; i < n; ++i) {
100+
y[i] = static_cast<T>(2) * x[i];
101+
}
102+
VSigmoidMKL(y, y, n);
103+
for (int i = 0; i < n; ++i) {
104+
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
105+
}
106+
}
83107
#endif
84108

85109
/* VExp JitKernel */
@@ -189,8 +213,63 @@ bool VSigmoidKernelImpl<double>::useMKL(int d) {
189213
}
190214
#endif
191215

216+
/* VTanh JitKernel */
217+
template <typename T>
218+
class VTanhKernelImpl : public VTanhKernel<T> {
219+
public:
220+
JITKERNEL_DECLARE_STATIC_FUNC;
221+
explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
222+
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
223+
#ifdef PADDLE_WITH_XBYAK
224+
if (useJIT(d)) {
225+
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change
226+
jitcode_.reset(new gen::VTanhJitCode(d, sz > 4096 ? sz : 4096));
227+
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
228+
return;
229+
}
230+
#endif
231+
232+
#ifdef PADDLE_WITH_MKLML
233+
// strictly speaking, MKL is the better impl when usable; otherwise fall back to the reference impl
234+
if (useMKL(d)) {
235+
this->Compute = VTanhMKL<T>;
236+
return;
237+
}
238+
#endif
239+
this->Compute = VTanhRefer<T>;
240+
}
241+
void ComputeDeprecated(const T* x, T* y) const override {
242+
VTanhRefer(x, y, this->num_);
243+
}
244+
#ifdef PADDLE_WITH_XBYAK
245+
246+
private:
247+
std::unique_ptr<gen::VTanhJitCode> jitcode_{nullptr};
248+
#endif
249+
};
250+
251+
#ifdef PADDLE_WITH_XBYAK
252+
template <>
253+
bool VTanhKernelImpl<float>::useJIT(int d) {
254+
return gen::VTanhJitCode::init(d);
255+
}
256+
#endif
257+
258+
#ifdef PADDLE_WITH_MKLML
259+
template <>
260+
bool VTanhKernelImpl<float>::useMKL(int d) {
261+
return d > 512;
262+
}
263+
264+
template <>
265+
bool VTanhKernelImpl<double>::useMKL(int d) {
266+
return true;
267+
}
268+
#endif
269+
192270
REGISTER_JITKERNEL(vexp, VExpKernel);
193271
REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel);
272+
REGISTER_JITKERNEL(vtanh, VTanhKernel);
194273

195274
namespace detail {
196275

@@ -337,156 +416,6 @@ __m256 ExpAVX2(__m256 x) {
337416
#endif
338417

339418
} // namespace detail
340-
341-
#define INTRI_SIGMOID(tmp, min, max, expisa) \
342-
tmp = _mm256_max_ps(tmp, min); \
343-
tmp = _mm256_min_ps(tmp, max); \
344-
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \
345-
tmp = expisa(tmp); \
346-
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
347-
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)
348-
#undef INTRI_VSIGMOID
349-
350-
/* VTanh JitKernel */
351-
template <typename T, jit::cpu_isa_t isa, jit_block>
352-
class VTanhKernelImpl : public VTanhKernel<T> {
353-
public:
354-
explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
355-
this->num_ = d;
356-
vscal_ = KernelPool::Instance().template Get<VScalKernel<T>>(d);
357-
vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<T>>(d);
358-
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
359-
}
360-
void ComputeDeprecated(const T* x, T* y) const override {
361-
const T a = static_cast<T>(2), b = static_cast<T>(-1);
362-
vscal_->Compute(&a, x, y, this->num_);
363-
vsigmoid_->ComputeDeprecated(y, y);
364-
vscal_->Compute(&a, y, y, this->num_);
365-
vaddbias_->Compute(&b, y, y, this->num_);
366-
}
367-
368-
private:
369-
std::shared_ptr<const VScalKernel<T>> vscal_;
370-
std::shared_ptr<const VSigmoidKernel<T>> vsigmoid_;
371-
std::shared_ptr<const VAddBiasKernel<T>> vaddbias_;
372-
};
373-
374-
#define INTRI_VTANH(tmp, expisa) \
375-
tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp); \
376-
tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \
377-
tmp = expisa(tmp); \
378-
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
379-
tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \
380-
tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))
381-
382-
#define INTRI8_FLOAT(isa, expisa) \
383-
template <> \
384-
void VTanhKernelImpl<float, isa, kEQ8>::ComputeDeprecated(const float* x, \
385-
float* y) const { \
386-
__m256 tmp = _mm256_loadu_ps(x); \
387-
INTRI_VTANH(tmp, expisa); \
388-
_mm256_storeu_ps(y, tmp); \
389-
}
390-
391-
#define INTRI16_FLOAT(isa, expisa) \
392-
template <> \
393-
void VTanhKernelImpl<float, isa, kEQ16>::ComputeDeprecated(const float* x, \
394-
float* y) const { \
395-
__m256 tmp0 = _mm256_loadu_ps(x); \
396-
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
397-
INTRI_VTANH(tmp0, expisa); \
398-
INTRI_VTANH(tmp1, expisa); \
399-
_mm256_storeu_ps(y, tmp0); \
400-
_mm256_storeu_ps(y + 8, tmp1); \
401-
}
402-
403-
#define INTRI_GT8LT16_FLOAT(isa, expisa) \
404-
template <> \
405-
VTanhKernelImpl<float, isa, kGT8LT16>::VTanhKernelImpl(int d) \
406-
: VTanhKernel<float>() { \
407-
this->num_ = d; \
408-
this->end_ = AVX_FLOAT_BLOCK; \
409-
this->rest_ = d - this->end_; \
410-
vscal_ = \
411-
KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_); \
412-
vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>( \
413-
this->rest_); \
414-
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>( \
415-
this->rest_); \
416-
} \
417-
template <> \
418-
void VTanhKernelImpl<float, isa, kGT8LT16>::ComputeDeprecated( \
419-
const float* x, float* y) const { \
420-
__m256 tmp = _mm256_loadu_ps(x); \
421-
INTRI_VTANH(tmp, expisa); \
422-
_mm256_storeu_ps(y, tmp); \
423-
x += AVX_FLOAT_BLOCK; \
424-
y += AVX_FLOAT_BLOCK; \
425-
const float a = 2.f, b = -1.f; \
426-
vscal_->Compute(&a, x, y, this->num_); \
427-
vsigmoid_->ComputeDeprecated(y, y); \
428-
vscal_->Compute(&a, y, y, this->num_); \
429-
vaddbias_->Compute(&b, y, y, this->num_); \
430-
}
431-
432-
#define INTRI_GT16_FLOAT(isa, expisa) \
433-
template <> \
434-
VTanhKernelImpl<float, isa, kGT16>::VTanhKernelImpl(int d) \
435-
: VTanhKernel<float>() { \
436-
this->num_ = d; \
437-
this->rest_ = d % AVX_FLOAT_BLOCK; \
438-
this->end_ = d - this->rest_; \
439-
vscal_ = \
440-
KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_); \
441-
vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>( \
442-
this->rest_); \
443-
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>( \
444-
this->rest_); \
445-
} \
446-
template <> \
447-
void VTanhKernelImpl<float, isa, kGT16>::ComputeDeprecated(const float* x, \
448-
float* y) const { \
449-
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
450-
__m256 tmp = _mm256_loadu_ps(x + i); \
451-
INTRI_VTANH(tmp, expisa); \
452-
_mm256_storeu_ps(y + i, tmp); \
453-
} \
454-
x += this->end_; \
455-
y += this->end_; \
456-
const float a = 2.f, b = -1.f; \
457-
vscal_->Compute(&a, x, y, this->num_); \
458-
vsigmoid_->ComputeDeprecated(y, y); \
459-
vscal_->Compute(&a, y, y, this->num_); \
460-
vaddbias_->Compute(&b, y, y, this->num_); \
461-
}
462-
463-
#ifdef __AVX__
464-
INTRI8_FLOAT(jit::avx, detail::ExpAVX);
465-
INTRI16_FLOAT(jit::avx, detail::ExpAVX);
466-
INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
467-
INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
468-
#endif
469-
#ifdef __AVX2__
470-
INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
471-
INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
472-
// maybe use avx at gt8lt16 and gt16
473-
#endif
474-
#ifdef __AVX512F__
475-
INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
476-
INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
477-
// maybe use avx at gt8lt16 and gt16
478-
#endif
479-
480-
#undef INTRI8_FLOAT
481-
#undef INTRI16_FLOAT
482-
#undef INTRI_GT8LT16_FLOAT
483-
#undef INTRI_GT16_FLOAT
484-
#undef INTRI_VTANH
485-
486-
REGISTER_JITKERNEL_DEPRECATED(vtanh, VTanhKernel);
487-
488-
#undef JITKERNEL_NEW_ACT_IMPL
489-
490419
} // namespace jitkernel
491420
} // namespace math
492421
} // namespace operators

0 commit comments

Comments
 (0)