@@ -25,13 +25,18 @@ limitations under the License. */
25
25
namespace paddle {
26
26
namespace operators {
27
27
namespace math {
28
- # ifdef __AVX__
28
+ namespace jitkernel {
29
29
// Forward declarations of the vectorized exp helpers used by the AVX
// activations. This text is a diff view: '-' lines belong to the removed
// revision, '+' lines to the added one, and bare numbers are gutter line
// numbers. The removed revision declared a single Exp(__m256); the added
// revision declares ExpAVX under __AVX__ and ExpAVX2 under __AVX2__.
// The definitions of these helpers are not visible in this section.
namespace detail {
30
- __m256 Exp (__m256 a);
31
- } // namespace detail
30
+ # ifdef __AVX__
31
+ __m256 ExpAVX (__m256 x);
32
32
#endif
33
33
34
- namespace jitkernel {
34
+ #ifdef __AVX2__
35
+ __m256 ExpAVX2 (__m256 x);
36
+ #endif
37
+
38
+ } // namespace detail
39
+
35
40
namespace jit = platform::jit;
36
41
37
42
#ifdef __AVX__
@@ -43,43 +48,72 @@ class AVXAct {
43
48
virtual __m256 Compute (__m256 x) const = 0;
44
49
};
45
50
46
// Primary AVXActImpl template, derived from AVXAct. Its Compute() override
// does nothing but raise through PADDLE_THROW, so it only acts as the
// fallback for (type, isa) combinations that have no explicit specialization.
// The added revision ('+' line) introduces a second template parameter of
// type jit::cpu_isa_t next to the activation type.
// NOTE(review): the thrown message reads "Unkown" - presumably a typo for
// "Unknown"; it is a runtime string and is left untouched here.
- template <act_type type>
51
+ template <act_type type, jit:: cpu_isa_t isa >
47
52
class AVXActImpl : public AVXAct {
48
53
public:
49
54
__m256 Compute (__m256 x) const override { PADDLE_THROW (" Unkown type!" ); }
50
55
};
51
56
52
// Removed revision: sigmoid specialization keyed on the activation type only.
// Clamps x to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX], negates it
// (0 - x), applies detail::Exp, and returns 1 / (1 + exp_result) lane-wise.
- template <>
53
- __m256 AVXActImpl<kSigmoid >::Compute(__m256 x) const {
54
- __m256 ones = _mm256_set1_ps (1 .0f );
55
- x = _mm256_max_ps (x, _mm256_set1_ps (SIGMOID_THRESHOLD_MIN));
56
- x = _mm256_min_ps (x, _mm256_set1_ps (SIGMOID_THRESHOLD_MAX));
57
- x = _mm256_sub_ps (_mm256_set1_ps (0 .0f ), x);
58
- x = detail::Exp (x);
59
- x = _mm256_add_ps (ones, x);
60
- return _mm256_div_ps (ones, x);
61
- }
57
// AVX_SIGMOID(isa, expisa): expands to the explicit specialization
// AVXActImpl<kSigmoid, isa>::Compute.
//   isa    - the ISA tag used as the second template argument.
//   expisa - the exp routine called on the negated, clamped input.
// The generated body clamps x to
// [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX], negates it (0 - x),
// calls expisa, and returns 1 / (1 + exp_result) lane-wise.
+ #define AVX_SIGMOID (isa, expisa ) \
58
+ template <> \
59
+ __m256 AVXActImpl<kSigmoid , isa>::Compute(__m256 x) const { \
60
+ __m256 ones = _mm256_set1_ps (1 .0f ); \
61
+ x = _mm256_max_ps (x, _mm256_set1_ps (SIGMOID_THRESHOLD_MIN)); \
62
+ x = _mm256_min_ps (x, _mm256_set1_ps (SIGMOID_THRESHOLD_MAX)); \
63
+ x = _mm256_sub_ps (_mm256_set1_ps (0 .0f ), x); \
64
+ x = expisa (x); \
65
+ x = _mm256_add_ps (ones, x); \
66
+ return _mm256_div_ps (ones, x); \
67
+ }
62
68
63
// Removed revision: tanh specialization keyed on the activation type only.
// Evaluates 2 / (1 + Exp(min(-2 * x, EXP_MAX_INPUT))) - 1 lane-wise; the
// min() caps the argument handed to detail::Exp at EXP_MAX_INPUT.
- template <>
64
- __m256 AVXActImpl<kTanh >::Compute(__m256 x) const {
65
- __m256 ones = _mm256_set1_ps (1 .0f );
66
- x = _mm256_mul_ps (_mm256_set1_ps (-2 .0f ), x);
67
- x = _mm256_min_ps (x, _mm256_set1_ps (EXP_MAX_INPUT));
68
- x = detail::Exp (x);
69
- x = _mm256_add_ps (ones, x);
70
- x = _mm256_div_ps (_mm256_set1_ps (2 .0f ), x);
71
- return _mm256_sub_ps (x, ones);
72
- }
69
// AVX_TANH(isa, expisa): expands to the explicit specialization
// AVXActImpl<kTanh, isa>::Compute.
//   isa    - the ISA tag used as the second template argument.
//   expisa - the exp routine applied to the scaled, capped input.
// The generated body evaluates
// 2 / (1 + expisa(min(-2 * x, EXP_MAX_INPUT))) - 1 lane-wise.
+ #define AVX_TANH (isa, expisa ) \
70
+ template <> \
71
+ __m256 AVXActImpl<kTanh , isa>::Compute(__m256 x) const { \
72
+ __m256 ones = _mm256_set1_ps (1 .0f ); \
73
+ x = _mm256_mul_ps (_mm256_set1_ps (-2 .0f ), x); \
74
+ x = _mm256_min_ps (x, _mm256_set1_ps (EXP_MAX_INPUT)); \
75
+ x = expisa (x); \
76
+ x = _mm256_add_ps (ones, x); \
77
+ x = _mm256_div_ps (_mm256_set1_ps (2 .0f ), x); \
78
+ return _mm256_sub_ps (x, ones); \
79
+ }
73
80
74
// Removed revision: relu specialization keyed on the activation type only.
// Returns the lane-wise maximum of x and an all-zero vector.
- template <>
75
- __m256 AVXActImpl<kRelu >::Compute(__m256 x) const {
76
- return _mm256_max_ps (x, _mm256_setzero_ps ());
77
- }
81
// AVX_RELU(isa): expands to the explicit specialization
// AVXActImpl<kRelu, isa>::Compute, which returns the lane-wise maximum of
// x and an all-zero vector. No exp routine is involved, so the macro takes
// only the ISA tag.
+ #define AVX_RELU (isa ) \
82
+ template <> \
83
+ __m256 AVXActImpl<kRelu , isa>::Compute(__m256 x) const { \
84
+ return _mm256_max_ps (x, _mm256_setzero_ps ()); \
85
+ }
86
+
87
// AVX_IDENTITY(isa): expands to the explicit specialization
// AVXActImpl<kIdentity, isa>::Compute, which hands its argument back
// unchanged.
+ #define AVX_IDENTITY (isa ) \
88
+ template <> \
89
+ __m256 AVXActImpl<kIdentity , isa>::Compute(__m256 x) const { \
90
+ return x; \
91
+ }
92
+
93
// FOR_EACH_AVX_ISA(macro_): invokes the single-argument macro macro_ once
// for each of jit::avx, jit::avx2 and jit::avx512f. The last invocation has
// no trailing semicolon, so the use site supplies it.
+ #define FOR_EACH_AVX_ISA (macro_ ) \
94
+ macro_ (jit::avx); \
95
+ macro_ (jit::avx2); \
96
+ macro_ (jit::avx512f)
97
+
98
// Emits the activation specializations and then retires the helper macros.
//  - relu and identity are generated for jit::avx, jit::avx2 and jit::avx512f.
//  - sigmoid and tanh for jit::avx use detail::ExpAVX.
//  - sigmoid and tanh for jit::avx2 and jit::avx512f use detail::ExpAVX2 and
//    are emitted only when __AVX2__ is defined.
// NOTE(review): when __AVX2__ is not defined, no sigmoid/tanh specialization
// is emitted here for jit::avx2 / jit::avx512f; if those combinations are
// still instantiated elsewhere they would presumably resolve to the primary
// AVXActImpl template - confirm that this is intended.
+ FOR_EACH_AVX_ISA(AVX_RELU);
99
+ FOR_EACH_AVX_ISA (AVX_IDENTITY);
100
+
101
+ AVX_SIGMOID (jit::avx, detail::ExpAVX);
102
+ AVX_TANH (jit::avx, detail::ExpAVX);
103
+
104
+ #ifdef __AVX2__
105
+ AVX_SIGMOID (jit::avx2, detail::ExpAVX2);
106
+ AVX_SIGMOID (jit::avx512f, detail::ExpAVX2);
107
+ AVX_TANH (jit::avx2, detail::ExpAVX2);
108
+ AVX_TANH (jit::avx512f, detail::ExpAVX2);
109
+ #endif
110
+
111
+ #undef FOR_EACH_AVX_ISA
112
+ #undef AVX_IDENTITY
113
+ #undef AVX_RELU
114
+ #undef AVX_TANH
115
+ #undef AVX_SIGMOID
78
116
79
// Removed revision: identity specialization keyed on the activation type
// only. Hands its argument back unchanged.
- template <>
80
- __m256 AVXActImpl<kIdentity >::Compute(__m256 x) const {
81
- return x;
82
- }
83
117
#endif
84
118
85
119
template <typename T>
@@ -119,23 +153,6 @@ class LSTMKernelImpl : public LSTMKernel<T> {
119
153
act_cell_d_ = GetActKernel<T>(act_cell, d);
120
154
vmul_d_ = KernelPool::Instance ().template Get <VMulKernel<T>>(d);
121
155
vadd_d_ = KernelPool::Instance ().template Get <VAddKernel<T>>(d);
122
- #ifdef __AVX__
123
- auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> {
124
- if (type == " sigmoid" ) {
125
- return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid >());
126
- } else if (type == " relu" ) {
127
- return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu >());
128
- } else if (type == " tanh" ) {
129
- return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh >());
130
- } else if (type == " identity" || type == " " ) {
131
- return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity >());
132
- }
133
- PADDLE_THROW (" Not support type: %s" , type);
134
- };
135
- avx_act_gate_ = GetAVXAct (act_gate);
136
- avx_act_cand_ = GetAVXAct (act_cand);
137
- avx_act_cell_ = GetAVXAct (act_cell);
138
- #endif
139
156
}
140
157
141
158
void ComputeCtHt (T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
@@ -175,26 +192,61 @@ class LSTMKernelImpl : public LSTMKernel<T> {
175
192
#endif
176
193
};
177
194
178
// INTRI8_FLOAT(isa), removed revision ('-' lines): expands to a single
// explicit specialization, LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt.
// It loads four 8-float vectors from gates at offsets 0, 8, 16 and 24
// (named c, i, f, o), then with unaligned loads/stores computes
//   ct = act_cand(c) * act_gate(i) + ct_1 * act_gate(f)
//   ht = act_cell(ct) * act_gate(o)
// wp_data and checked are not referenced in the visible body.
- #define INTRI8_FLOAT (isa ) \
179
- template <> \
180
- void LSTMKernelImpl<float , isa, kEQ8 >::ComputeCtHt( \
181
- float * gates, const float * ct_1, float * ct, float * ht, \
182
- const float * wp_data, float * checked) const { \
183
- /* gates: W_ch, W_ih, W_fh, W_oh */ \
184
- __m256 c, i, f, o; \
185
- c = _mm256_loadu_ps (gates); \
186
- i = _mm256_loadu_ps (gates + 8 ); \
187
- f = _mm256_loadu_ps (gates + 16 ); \
188
- o = _mm256_loadu_ps (gates + 24 ); \
189
- /* C_t = C_t-1 * fgated + cand_gated * igated*/ \
190
- c = _mm256_mul_ps (avx_act_cand_->Compute (c), avx_act_gate_->Compute (i)); \
191
- i = _mm256_loadu_ps (ct_1); \
192
- f = _mm256_mul_ps (i, avx_act_gate_->Compute (f)); \
193
- f = _mm256_add_ps (c, f); \
194
- _mm256_storeu_ps (ct, f); \
195
- /* H_t = act_cell(C_t) * ogated */ \
196
- o = _mm256_mul_ps (avx_act_cell_->Compute (f), avx_act_gate_->Compute (o)); \
197
- _mm256_storeu_ps (ht, o); \
195
// INTRI8_FLOAT(isa), added revision ('+' lines): expands to three explicit
// specializations of LSTMKernelImpl<float, isa, kEQ8>.
//  1. Constructor: builds avx_act_gate_, avx_act_cand_ and avx_act_cell_ by
//     mapping each activation-name string to an AVXActImpl<..., isa> object
//     through the local GetAVXAct lambda; a name matching none of the
//     compared strings reaches PADDLE_THROW. The parameter d is not
//     referenced in the visible body.
//  2. ComputeCtHt: same arithmetic as the removed revision -
//       ct = act_cand(c) * act_gate(i) + ct_1 * act_gate(f)
//       ht = act_cell(ct) * act_gate(o)
//  3. ComputeC1H1: variant without a previous cell state; it reads only the
//     vectors at gates, gates + 8 and gates + 24 and computes
//       ct = act_gate(i) * act_cand(c)
//       ht = act_cell(ct) * act_gate(o)
// NOTE(review): this constructor initializes only the LSTMKernel<float> base
// and the three avx_act_* members; whether other members set by the generic
// constructor are needed for this specialization cannot be told from here.
+ #define INTRI8_FLOAT (isa ) \
196
+ template <> \
197
+ LSTMKernelImpl<float , isa, kEQ8 >::LSTMKernelImpl( \
198
+ const std::string& act_gate, const std::string& act_cand, \
199
+ const std::string& act_cell, int d) \
200
+ : LSTMKernel<float >() { \
201
+ auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \
202
+ if (type == " sigmoid" ) { \
203
+ return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid , isa>()); \
204
+ } else if (type == " relu" ) { \
205
+ return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu , isa>()); \
206
+ } else if (type == " tanh" ) { \
207
+ return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh , isa>()); \
208
+ } else if (type == " identity" || type == " " ) { \
209
+ return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity , isa>()); \
210
+ } \
211
+ PADDLE_THROW (" Not support type: %s" , type); \
212
+ }; \
213
+ avx_act_gate_ = GetAVXAct (act_gate); \
214
+ avx_act_cand_ = GetAVXAct (act_cand); \
215
+ avx_act_cell_ = GetAVXAct (act_cell); \
216
+ } \
217
+ template <> \
218
+ void LSTMKernelImpl<float , isa, kEQ8 >::ComputeCtHt( \
219
+ float * gates, const float * ct_1, float * ct, float * ht, \
220
+ const float * wp_data, float * checked) const { \
221
+ /* gates: W_ch, W_ih, W_fh, W_oh */ \
222
+ __m256 c, i, f, o; \
223
+ c = _mm256_loadu_ps (gates); \
224
+ i = _mm256_loadu_ps (gates + 8 ); \
225
+ f = _mm256_loadu_ps (gates + 16 ); \
226
+ o = _mm256_loadu_ps (gates + 24 ); \
227
+ /* C_t = C_t-1 * fgated + cand_gated * igated*/ \
228
+ c = _mm256_mul_ps (avx_act_cand_->Compute (c), avx_act_gate_->Compute (i)); \
229
+ i = _mm256_loadu_ps (ct_1); \
230
+ f = _mm256_mul_ps (i, avx_act_gate_->Compute (f)); \
231
+ f = _mm256_add_ps (c, f); \
232
+ _mm256_storeu_ps (ct, f); \
233
+ /* H_t = act_cell(C_t) * ogated */ \
234
+ o = _mm256_mul_ps (avx_act_cell_->Compute (f), avx_act_gate_->Compute (o)); \
235
+ _mm256_storeu_ps (ht, o); \
236
+ } \
237
+ template <> \
238
+ void LSTMKernelImpl<float , isa, kEQ8 >::ComputeC1H1( \
239
+ float * gates, float * ct, float * ht, const float * wp_data) const { \
240
+ __m256 c, i, o; \
241
+ c = _mm256_loadu_ps (gates); \
242
+ i = _mm256_loadu_ps (gates + 8 ); \
243
+ o = _mm256_loadu_ps (gates + 24 ); \
244
+ /* C_t = igated * cgated*/ \
245
+ c = _mm256_mul_ps (avx_act_gate_->Compute (i), avx_act_cand_->Compute (c)); \
246
+ _mm256_storeu_ps (ct, c); \
247
+ /* H_t = act_cell(C_t) * ogated */ \
248
+ o = _mm256_mul_ps (avx_act_cell_->Compute (c), avx_act_gate_->Compute (o)); \
249
+ _mm256_storeu_ps (ht, o); \
198
250
}
199
251
200
252
// TODO(TJ): optimize keq16
0 commit comments