@@ -16,6 +16,7 @@ limitations under the License. */
16
16
#include < string>
17
17
#include " paddle/fluid/operators/math/jit_kernel_macro.h"
18
18
#include " paddle/fluid/platform/enforce.h"
19
+ #include " paddle/fluid/platform/macros.h"
19
20
20
21
#ifdef __AVX__
21
22
#include < immintrin.h>
@@ -24,10 +25,63 @@ limitations under the License. */
24
25
namespace paddle {
25
26
namespace operators {
26
27
namespace math {
27
- namespace jitkernel {
28
+ #ifdef __AVX__
29
+ namespace detail {
30
+ __m256 Exp (__m256 a);
31
+ } // namespace detail
32
+ #endif
28
33
34
+ namespace jitkernel {
29
35
namespace jit = platform::jit;
30
36
37
+ #ifdef __AVX__
38
// Activation kinds that have a hand-written AVX fast path below.
enum act_type { kSigmoid, kRelu, kTanh, kIdentity };
39
+
40
// Interface for an activation function applied to eight packed floats.
// Concrete activations are supplied by the AVXActImpl<act_type>
// specializations that follow.
class AVXAct {
 public:
  virtual ~AVXAct() = default;
  // Apply the activation elementwise to the lanes of x.
  virtual __m256 Compute(__m256 x) const = 0;
};
45
+
46
+ template <act_type type>
47
+ class AVXActImpl : public AVXAct {
48
+ public:
49
+ __m256 Compute (__m256 x) const override { PADDLE_THROW (" Unkown type!" ); }
50
+ };
51
+
52
+ template <>
53
+ __m256 AVXActImpl<kSigmoid >::Compute(__m256 x) const {
54
+ __m256 ones = _mm256_set1_ps (1 .0f );
55
+ x = _mm256_max_ps (x, _mm256_set1_ps (SIGMOID_THRESHOLD_MIN));
56
+ x = _mm256_min_ps (x, _mm256_set1_ps (SIGMOID_THRESHOLD_MAX));
57
+ x = _mm256_sub_ps (_mm256_set1_ps (0 .0f ), x);
58
+ x = detail::Exp (x);
59
+ x = _mm256_add_ps (ones, x);
60
+ return _mm256_div_ps (ones, x);
61
+ }
62
+
63
+ template <>
64
+ __m256 AVXActImpl<kTanh >::Compute(__m256 x) const {
65
+ __m256 ones = _mm256_set1_ps (1 .0f );
66
+ x = _mm256_mul_ps (_mm256_set1_ps (-2 .0f ), x);
67
+ x = _mm256_min_ps (x, _mm256_set1_ps (EXP_MAX_INPUT));
68
+ x = detail::Exp (x);
69
+ x = _mm256_add_ps (ones, x);
70
+ x = _mm256_div_ps (_mm256_set1_ps (2 .0f ), x);
71
+ return _mm256_sub_ps (x, ones);
72
+ }
73
+
74
+ template <>
75
+ __m256 AVXActImpl<kRelu >::Compute(__m256 x) const {
76
+ return _mm256_max_ps (x, _mm256_setzero_ps ());
77
+ }
78
+
79
+ template <>
80
+ __m256 AVXActImpl<kIdentity >::Compute(__m256 x) const {
81
+ return x;
82
+ }
83
+ #endif
84
+
31
85
/* LSTM JitKernel */
32
86
template <typename T, jit::cpu_isa_t isa, jit_block>
33
87
class LSTMKernelImpl : public LSTMKernel <T> {
@@ -61,6 +115,23 @@ class LSTMKernelImpl : public LSTMKernel<T> {
61
115
act_cell_d_ = GetActKernel (act_cell, d);
62
116
vmul_d_ = KernelPool::Instance ().template Get <VMulKernel<T>>(d);
63
117
vadd_d_ = KernelPool::Instance ().template Get <VAddKernel<T>>(d);
118
+ #ifdef __AVX__
119
+ auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> {
120
+ if (type == " sigmoid" ) {
121
+ return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid >());
122
+ } else if (type == " relu" ) {
123
+ return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu >());
124
+ } else if (type == " tanh" ) {
125
+ return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh >());
126
+ } else if (type == " identity" || type == " " ) {
127
+ return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity >());
128
+ }
129
+ PADDLE_THROW (" Not support type: %s" , type);
130
+ };
131
+ avx_act_gate_ = GetAVXAct (act_gate);
132
+ avx_act_cand_ = GetAVXAct (act_cand);
133
+ avx_act_cell_ = GetAVXAct (act_cell);
134
+ #endif
64
135
}
65
136
66
137
void ComputeCtHt (T* gates, const T* ct_1, T* ct, T* ht) const override {
@@ -83,8 +154,44 @@ class LSTMKernelImpl : public LSTMKernel<T> {
83
154
std::shared_ptr<const VActKernel<T>> act_gate_3d_, act_cand_d_, act_cell_d_;
84
155
std::shared_ptr<const VMulKernel<T>> vmul_d_;
85
156
std::shared_ptr<const VAddKernel<T>> vadd_d_;
157
+ #ifdef __AVX__
158
+ std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_cand_, avx_act_cell_;
159
+ #endif
86
160
};
87
161
162
// Fused AVX ComputeCtHt for float LSTM kernels with block size 8:
// one __m256 covers the whole gate, so the four gates are processed
// with a handful of intrinsics instead of the generic kernel calls.
#define INTRI8_FLOAT(isa)                                                    \
  template <>                                                                \
  void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt(                        \
      float* gates, const float* ct_1, float* ct, float* ht) const {         \
    /* gates layout: W_ch, W_ih, W_fh, W_oh */                               \
    __m256 cand = _mm256_loadu_ps(gates);                                    \
    __m256 igate = _mm256_loadu_ps(gates + 8);                               \
    __m256 fgate = _mm256_loadu_ps(gates + 16);                              \
    __m256 ogate = _mm256_loadu_ps(gates + 24);                              \
    /* C_t = act_cand(cand) * act_gate(i) + C_{t-1} * act_gate(f) */         \
    cand = _mm256_mul_ps(avx_act_cand_->Compute(cand),                       \
                         avx_act_gate_->Compute(igate));                     \
    fgate = _mm256_mul_ps(_mm256_loadu_ps(ct_1),                             \
                          avx_act_gate_->Compute(fgate));                    \
    __m256 ct_vec = _mm256_add_ps(cand, fgate);                              \
    _mm256_storeu_ps(ct, ct_vec);                                            \
    /* H_t = act_cell(C_t) * act_gate(o) */                                  \
    _mm256_storeu_ps(ht, _mm256_mul_ps(avx_act_cell_->Compute(ct_vec),       \
                                       avx_act_gate_->Compute(ogate)));      \
  }
182
+
183
+ // TODO(TJ): optimize keq16
184
+
185
+ #ifdef __AVX__
186
+ INTRI8_FLOAT (jit::avx);
187
+ #endif
188
+ #ifdef __AVX2__
189
+ INTRI8_FLOAT (jit::avx2);
190
+ #endif
191
+ #ifdef __AVX512F__
192
+ INTRI8_FLOAT (jit::avx512f);
193
+ #endif
194
+
88
195
#define JITKERNEL_DECLARE_LSTM (ker_class, ker_dtype ) \
89
196
template <> \
90
197
std::shared_ptr<const ker_class<ker_dtype>> \
@@ -104,6 +211,7 @@ class LSTMKernelImpl : public LSTMKernel<T> {
104
211
REGISTER_JITKERNEL_ARGS (lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
105
212
JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL);
106
213
214
+ #undef INTRI8_FLOAT
107
215
#undef JITKERNEL_DECLARE_LSTM
108
216
#undef JITKERNEL_KEY_LSTM
109
217
#undef JITKERNEL_NEW_LSTM_IMPL
0 commit comments