Skip to content

Commit 2a00969

Browse files
committed
optimize lstm jitkernel keq8
test=develop
1 parent f2adaf1 commit 2a00969

File tree

2 files changed

+111
-2
lines changed

2 files changed

+111
-2
lines changed

paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,6 @@ endif()
7777
cc_test(concat_test SRCS concat_test.cc DEPS concat)
7878
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
7979
cc_library(jit_kernel_exp SRCS jit_kernel_exp.cc DEPS cpu_info cblas activation_functions)
80-
cc_library(jit_kernel SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_lstm.cc DEPS cpu_info cblas jit_kernel_exp)
80+
# jit_kernel_lstm.cc is built as its own library, with activation_functions among its dependencies.
cc_library(jit_kernel_lstm SRCS jit_kernel_lstm.cc DEPS cpu_info cblas activation_functions)
81+
# jit_kernel links against jit_kernel_lstm rather than compiling jit_kernel_lstm.cc itself.
cc_library(jit_kernel SRCS jit_kernel.cc jit_kernel_blas.cc DEPS cpu_info cblas jit_kernel_exp jit_kernel_lstm)
8182
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)

paddle/fluid/operators/math/jit_kernel_lstm.cc

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License. */
1616
#include <string>
1717
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
1818
#include "paddle/fluid/platform/enforce.h"
19+
#include "paddle/fluid/platform/macros.h"
1920

2021
#ifdef __AVX__
2122
#include <immintrin.h>
@@ -24,10 +25,63 @@ limitations under the License. */
2425
namespace paddle {
2526
namespace operators {
2627
namespace math {
27-
namespace jitkernel {
28+
#ifdef __AVX__
29+
namespace detail {
30+
// Forward declaration only: Exp takes a __m256 and returns a __m256.
// Its definition is not part of this file; the AVX sigmoid and tanh
// activations in this file call it.
__m256 Exp(__m256 a);
31+
} // namespace detail
32+
#endif
2833

34+
namespace jitkernel {
2935
namespace jit = platform::jit;
3036

37+
#ifdef __AVX__
38+
// Activation kinds; used as the template argument that selects an
// AVXActImpl specialization.
typedef enum { kSigmoid, kRelu, kTanh, kIdentity } act_type;
39+
40+
// Abstract interface for an activation applied to a single __m256 vector.
// The concrete implementations are the AVXActImpl<act_type> specializations.
class AVXAct {
41+
public:
42+
// Virtual so that implementations can be destroyed through an AVXAct pointer.
virtual ~AVXAct() = default;
43+
// Returns the activation of x; each AVXActImpl specialization defines it.
virtual __m256 Compute(__m256 x) const = 0;
44+
};
45+
46+
// Activation implementation selected by act_type. The useful behaviour lives
// in the explicit specializations of Compute (kSigmoid, kTanh, kRelu,
// kIdentity); this primary definition only reports an error through
// PADDLE_THROW for an act_type that has no specialization.
template <act_type type>
47+
class AVXActImpl : public AVXAct {
48+
public:
49+
__m256 Compute(__m256 x) const override { PADDLE_THROW("Unknown type!"); }
50+
};
51+
52+
// Sigmoid on a vector of packed floats: returns 1 / (1 + exp(-x)).
// x is clamped to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX] before the
// exponential is taken; both constants are defined outside this file.
template <>
53+
__m256 AVXActImpl<kSigmoid>::Compute(__m256 x) const {
54+
__m256 ones = _mm256_set1_ps(1.0f);
55+
// Clamp the input from below, then from above.
x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN));
56+
x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX));
57+
// Negate: x = 0 - x.
x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x);
58+
x = detail::Exp(x);
59+
// Denominator: 1 + exp(-x).
x = _mm256_add_ps(ones, x);
60+
return _mm256_div_ps(ones, x);
61+
}
62+
63+
// Tanh on a vector of packed floats, computed as 2 / (1 + exp(-2x)) - 1.
// The exponent argument (-2x) is capped at EXP_MAX_INPUT, a constant defined
// outside this file, before detail::Exp is called.
template <>
64+
__m256 AVXActImpl<kTanh>::Compute(__m256 x) const {
65+
__m256 ones = _mm256_set1_ps(1.0f);
66+
// x = -2 * x
x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x);
67+
// NOTE(review): presumably this cap keeps Exp from overflowing -- confirm.
x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT));
68+
x = detail::Exp(x);
69+
// 1 + exp(-2x)
x = _mm256_add_ps(ones, x);
70+
// 2 / (1 + exp(-2x))
x = _mm256_div_ps(_mm256_set1_ps(2.0f), x);
71+
return _mm256_sub_ps(x, ones);
72+
}
73+
74+
// ReLU on a vector of packed floats: element-wise max(x, 0).
template <>
75+
__m256 AVXActImpl<kRelu>::Compute(__m256 x) const {
76+
return _mm256_max_ps(x, _mm256_setzero_ps());
77+
}
78+
79+
// Identity activation: returns x unchanged.
template <>
80+
__m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
81+
return x;
82+
}
83+
#endif
84+
3185
/* LSTM JitKernel */
3286
template <typename T, jit::cpu_isa_t isa, jit_block>
3387
class LSTMKernelImpl : public LSTMKernel<T> {
@@ -61,6 +115,23 @@ class LSTMKernelImpl : public LSTMKernel<T> {
61115
act_cell_d_ = GetActKernel(act_cell, d);
62116
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
63117
vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
118+
#ifdef __AVX__
119+
// Maps an activation name to a newly allocated AVXAct implementation.
// "sigmoid", "relu" and "tanh" select the matching specialization; "identity"
// and the empty string both select the identity activation. Any other name is
// reported through PADDLE_THROW.
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> {
120+
if (type == "sigmoid") {
121+
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid>());
122+
} else if (type == "relu") {
123+
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu>());
124+
} else if (type == "tanh") {
125+
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh>());
126+
} else if (type == "identity" || type == "") {
127+
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity>());
128+
}
129+
// Reached only when no branch above matched.
PADDLE_THROW("Not support type: %s", type);
130+
};
131+
// One activation object each for the gate, candidate and cell activations.
avx_act_gate_ = GetAVXAct(act_gate);
132+
avx_act_cand_ = GetAVXAct(act_cand);
133+
avx_act_cell_ = GetAVXAct(act_cell);
134+
#endif
64135
}
65136

66137
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht) const override {
@@ -83,8 +154,44 @@ class LSTMKernelImpl : public LSTMKernel<T> {
83154
std::shared_ptr<const VActKernel<T>> act_gate_3d_, act_cand_d_, act_cell_d_;
84155
std::shared_ptr<const VMulKernel<T>> vmul_d_;
85156
std::shared_ptr<const VAddKernel<T>> vadd_d_;
157+
#ifdef __AVX__
158+
// AVX activation objects for the gate, candidate and cell activations; the
// AVX ComputeCtHt specialization (INTRI8_FLOAT) calls Compute on them.
std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_cand_, avx_act_cell_;
159+
#endif
86160
};
87161

162+
// Defines the ComputeCtHt specialization of LSTMKernelImpl<float, isa, kEQ8>.
// It reads four 8-float vectors from gates at offsets 0, 8, 16 and 24
// (named c, i, f, o) and one 8-float vector from ct_1, then writes:
//   ct = act_cand(c) * act_gate(i) + ct_1 * act_gate(f)
//   ht = act_cell(ct) * act_gate(o)
// All loads and stores use the unaligned (loadu/storeu) intrinsics.
#define INTRI8_FLOAT(isa) \
163+
template <> \
164+
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
165+
float* gates, const float* ct_1, float* ct, float* ht) const { \
166+
/* gates: W_ch, W_ih, W_fh, W_oh */ \
167+
__m256 c, i, f, o; \
168+
c = _mm256_loadu_ps(gates); \
169+
i = _mm256_loadu_ps(gates + 8); \
170+
f = _mm256_loadu_ps(gates + 16); \
171+
o = _mm256_loadu_ps(gates + 24); \
172+
/* C_t = C_t-1 * fgated + cand_gated * igated; below, i is reused to hold C_t-1 */ \
173+
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
174+
i = _mm256_loadu_ps(ct_1); \
175+
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
176+
f = _mm256_add_ps(c, f); \
177+
_mm256_storeu_ps(ct, f); \
178+
/* H_t = act_cell(C_t) * ogated */ \
179+
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
180+
_mm256_storeu_ps(ht, o); \
181+
}
182+
183+
// TODO(TJ): optimize keq16
184+
185+
// Emit the kEQ8 float ComputeCtHt specialization once per ISA tag, each
// guarded by the matching compiler feature macro.
#ifdef __AVX__
186+
INTRI8_FLOAT(jit::avx);
187+
#endif
188+
#ifdef __AVX2__
189+
INTRI8_FLOAT(jit::avx2);
190+
#endif
191+
#ifdef __AVX512F__
192+
INTRI8_FLOAT(jit::avx512f);
193+
#endif
194+
88195
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
89196
template <> \
90197
std::shared_ptr<const ker_class<ker_dtype>> \
@@ -104,6 +211,7 @@ class LSTMKernelImpl : public LSTMKernel<T> {
104211
REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
105212
JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL);
106213

214+
#undef INTRI8_FLOAT
107215
#undef JITKERNEL_DECLARE_LSTM
108216
#undef JITKERNEL_KEY_LSTM
109217
#undef JITKERNEL_NEW_LSTM_IMPL

0 commit comments

Comments
 (0)