@@ -136,6 +136,21 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel(
136
136
return nullptr ;
137
137
}
138
138
139
+ template <jit::cpu_isa_t isa>
140
+ static std::unique_ptr<AVXAct> GetAVXAct (const std::string& type) {
141
+ if (type == " sigmoid" ) {
142
+ return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid , isa>());
143
+ } else if (type == " relu" ) {
144
+ return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu , isa>());
145
+ } else if (type == " tanh" ) {
146
+ return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh , isa>());
147
+ } else if (type == " identity" || type == " " ) {
148
+ return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity , isa>());
149
+ }
150
+ PADDLE_THROW (" Not support type: %s" , type);
151
+ return nullptr ;
152
+ }
153
+
139
154
/* LSTM JitKernel */
140
155
template <typename T, jit::cpu_isa_t isa, jit_block>
141
156
class LSTMKernelImpl : public LSTMKernel <T> {
@@ -192,61 +207,49 @@ class LSTMKernelImpl : public LSTMKernel<T> {
192
207
#endif
193
208
};
194
209
195
- #define INTRI8_FLOAT (isa ) \
196
- template <> \
197
- LSTMKernelImpl<float , isa, kEQ8 >::LSTMKernelImpl( \
198
- const std::string& act_gate, const std::string& act_cand, \
199
- const std::string& act_cell, int d) \
200
- : LSTMKernel<float >() { \
201
- auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \
202
- if (type == " sigmoid" ) { \
203
- return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid , isa>()); \
204
- } else if (type == " relu" ) { \
205
- return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu , isa>()); \
206
- } else if (type == " tanh" ) { \
207
- return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh , isa>()); \
208
- } else if (type == " identity" || type == " " ) { \
209
- return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity , isa>()); \
210
- } \
211
- PADDLE_THROW (" Not support type: %s" , type); \
212
- }; \
213
- avx_act_gate_ = GetAVXAct (act_gate); \
214
- avx_act_cand_ = GetAVXAct (act_cand); \
215
- avx_act_cell_ = GetAVXAct (act_cell); \
216
- } \
217
- template <> \
218
- void LSTMKernelImpl<float , isa, kEQ8 >::ComputeCtHt( \
219
- float * gates, const float * ct_1, float * ct, float * ht, \
220
- const float * wp_data, float * checked) const { \
221
- /* gates: W_ch, W_ih, W_fh, W_oh */ \
222
- __m256 c, i, f, o; \
223
- c = _mm256_loadu_ps (gates); \
224
- i = _mm256_loadu_ps (gates + 8 ); \
225
- f = _mm256_loadu_ps (gates + 16 ); \
226
- o = _mm256_loadu_ps (gates + 24 ); \
227
- /* C_t = C_t-1 * fgated + cand_gated * igated*/ \
228
- c = _mm256_mul_ps (avx_act_cand_->Compute (c), avx_act_gate_->Compute (i)); \
229
- i = _mm256_loadu_ps (ct_1); \
230
- f = _mm256_mul_ps (i, avx_act_gate_->Compute (f)); \
231
- f = _mm256_add_ps (c, f); \
232
- _mm256_storeu_ps (ct, f); \
233
- /* H_t = act_cell(C_t) * ogated */ \
234
- o = _mm256_mul_ps (avx_act_cell_->Compute (f), avx_act_gate_->Compute (o)); \
235
- _mm256_storeu_ps (ht, o); \
236
- } \
237
- template <> \
238
- void LSTMKernelImpl<float , isa, kEQ8 >::ComputeC1H1( \
239
- float * gates, float * ct, float * ht, const float * wp_data) const { \
240
- __m256 c, i, o; \
241
- c = _mm256_loadu_ps (gates); \
242
- i = _mm256_loadu_ps (gates + 8 ); \
243
- o = _mm256_loadu_ps (gates + 24 ); \
244
- /* C_t = igated * cgated*/ \
245
- c = _mm256_mul_ps (avx_act_gate_->Compute (i), avx_act_cand_->Compute (c)); \
246
- _mm256_storeu_ps (ct, c); \
247
- /* H_t = act_cell(C_t) * ogated */ \
248
- o = _mm256_mul_ps (avx_act_cell_->Compute (c), avx_act_gate_->Compute (o)); \
249
- _mm256_storeu_ps (ht, o); \
210
+ #define INTRI8_FLOAT (isa ) \
211
+ template <> \
212
+ LSTMKernelImpl<float , isa, kEQ8 >::LSTMKernelImpl( \
213
+ const std::string& act_gate, const std::string& act_cand, \
214
+ const std::string& act_cell, int d) \
215
+ : LSTMKernel<float >() { \
216
+ avx_act_gate_ = GetAVXAct<isa>(act_gate); \
217
+ avx_act_cand_ = GetAVXAct<isa>(act_cand); \
218
+ avx_act_cell_ = GetAVXAct<isa>(act_cell); \
219
+ } \
220
+ template <> \
221
+ void LSTMKernelImpl<float , isa, kEQ8 >::ComputeCtHt( \
222
+ float * gates, const float * ct_1, float * ct, float * ht, \
223
+ const float * wp_data, float * checked) const { \
224
+ /* gates: W_ch, W_ih, W_fh, W_oh */ \
225
+ __m256 c, i, f, o; \
226
+ c = _mm256_loadu_ps (gates); \
227
+ i = _mm256_loadu_ps (gates + 8 ); \
228
+ f = _mm256_loadu_ps (gates + 16 ); \
229
+ o = _mm256_loadu_ps (gates + 24 ); \
230
+ /* C_t = C_t-1 * fgated + cand_gated * igated*/ \
231
+ c = _mm256_mul_ps (avx_act_cand_->Compute (c), avx_act_gate_->Compute (i)); \
232
+ i = _mm256_loadu_ps (ct_1); \
233
+ f = _mm256_mul_ps (i, avx_act_gate_->Compute (f)); \
234
+ f = _mm256_add_ps (c, f); \
235
+ _mm256_storeu_ps (ct, f); \
236
+ /* H_t = act_cell(C_t) * ogated */ \
237
+ o = _mm256_mul_ps (avx_act_cell_->Compute (f), avx_act_gate_->Compute (o)); \
238
+ _mm256_storeu_ps (ht, o); \
239
+ } \
240
+ template <> \
241
+ void LSTMKernelImpl<float , isa, kEQ8 >::ComputeC1H1( \
242
+ float * gates, float * ct, float * ht, const float * wp_data) const { \
243
+ __m256 c, i, o; \
244
+ c = _mm256_loadu_ps (gates); \
245
+ i = _mm256_loadu_ps (gates + 8 ); \
246
+ o = _mm256_loadu_ps (gates + 24 ); \
247
+ /* C_t = igated * cgated*/ \
248
+ c = _mm256_mul_ps (avx_act_gate_->Compute (i), avx_act_cand_->Compute (c)); \
249
+ _mm256_storeu_ps (ct, c); \
250
+ /* H_t = act_cell(C_t) * ogated */ \
251
+ o = _mm256_mul_ps (avx_act_cell_->Compute (c), avx_act_gate_->Compute (o)); \
252
+ _mm256_storeu_ps (ht, o); \
250
253
}
251
254
252
255
// TODO(TJ): optimize keq16
@@ -375,6 +378,7 @@ class GRUKernelImpl : public GRUKernel<T> {
375
378
act_state_d_->Compute (gates + d2_, gates + d2_);
376
379
vmul_d_->Compute (gates, gates + d2_, ht);
377
380
}
381
+
378
382
void ComputeHtPart1 (T* gates, const T* ht_1, T* ht) const override {
379
383
// W: {W_update, W_reset; W_state}
380
384
act_gate_d2_->Compute (gates, gates);
@@ -394,8 +398,65 @@ class GRUKernelImpl : public GRUKernel<T> {
394
398
int d_, d2_;
395
399
std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_state_d_;
396
400
std::shared_ptr<const VMulKernel<T>> vmul_d_;
401
+ #ifdef __AVX__
402
+ std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_state_;
403
+ #endif
397
404
};
398
405
406
+ #define INTRI8_FLOAT (isa ) \
407
+ template <> \
408
+ GRUKernelImpl<float , isa, kEQ8 >::GRUKernelImpl( \
409
+ const std::string& act_gate, const std::string& act_state, int d) \
410
+ : GRUKernel<float >() { \
411
+ avx_act_gate_ = GetAVXAct<isa>(act_gate); \
412
+ avx_act_state_ = GetAVXAct<isa>(act_state); \
413
+ } \
414
+ template <> \
415
+ void GRUKernelImpl<float , isa, kEQ8 >::ComputeH1(float * gates, float * ht) \
416
+ const { \
417
+ __m256 u, s; \
418
+ /* W: {W_update, W_reset; W_state} */ \
419
+ u = _mm256_loadu_ps (gates); \
420
+ s = _mm256_loadu_ps (gates + 16 ); \
421
+ s = _mm256_mul_ps (avx_act_gate_->Compute (u), avx_act_state_->Compute (s)); \
422
+ _mm256_storeu_ps (ht, s); \
423
+ } \
424
+ template <> \
425
+ void GRUKernelImpl<float , isa, kEQ8 >::ComputeHtPart1( \
426
+ float * gates, const float * ht_1, float * ht) const { \
427
+ /* not exactly equal the any implementation */ \
428
+ __m256 r, ht0; \
429
+ r = _mm256_loadu_ps (gates + 8 ); \
430
+ ht0 = _mm256_loadu_ps (ht_1); \
431
+ r = _mm256_mul_ps (avx_act_gate_->Compute (r), ht0); \
432
+ _mm256_storeu_ps (ht, r); \
433
+ } \
434
+ template <> \
435
+ void GRUKernelImpl<float , isa, kEQ8 >::ComputeHtPart2( \
436
+ float * gates, const float * ht_1, float * ht) const { \
437
+ /* not exactly equal the any implementation */ \
438
+ __m256 u, s, ht0; \
439
+ u = _mm256_loadu_ps (gates); \
440
+ s = _mm256_loadu_ps (gates + 16 ); \
441
+ ht0 = _mm256_loadu_ps (ht_1); \
442
+ u = avx_act_gate_->Compute (u); \
443
+ s = _mm256_mul_ps (u, avx_act_state_->Compute (s)); \
444
+ u = _mm256_sub_ps (_mm256_set1_ps (1 .f ), u); \
445
+ u = _mm256_mul_ps (u, ht0); \
446
+ u = _mm256_add_ps (s, u); \
447
+ _mm256_storeu_ps (ht, u); \
448
+ }
449
+
450
+ #ifdef __AVX__
451
+ INTRI8_FLOAT (jit::avx);
452
+ #endif
453
+ #ifdef __AVX2__
454
+ INTRI8_FLOAT (jit::avx2);
455
+ #endif
456
+ #ifdef __AVX512F__
457
+ INTRI8_FLOAT (jit::avx512f);
458
+ #endif
459
+
399
460
#define JITKERNEL_DECLARE_GRU (ker_class, ker_dtype ) \
400
461
template <> \
401
462
std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get< \
@@ -412,6 +473,7 @@ class GRUKernelImpl : public GRUKernel<T> {
412
473
REGISTER_JITKERNEL_ARGS (gru, GRUKernel, JITKERNEL_DECLARE_GRU,
413
474
JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL);
414
475
476
+ #undef INTRI8_FLOAT
415
477
#undef JITKERNEL_NEW_GRU_IMPL
416
478
#undef JITKERNEL_KEY_GRU
417
479
#undef JITKERNEL_DECLARE_GRU
0 commit comments