@@ -45,6 +45,7 @@ void VExpRefer(const T* x, T* y, int n) {
45
45
46
46
template <typename T>
47
47
void VSigmoidRefer (const T* x, T* y, int n) {
48
+ // y = 1 / (1 + e^-x)
48
49
const T min = SIGMOID_THRESHOLD_MIN;
49
50
const T max = SIGMOID_THRESHOLD_MAX;
50
51
for (int i = 0 ; i < n; ++i) {
@@ -53,6 +54,18 @@ void VSigmoidRefer(const T* x, T* y, int n) {
53
54
}
54
55
}
55
56
57
template <typename T>
void VTanhRefer(const T* x, T* y, int n) {
  // Reference (scalar) tanh: tanh(x) = 2 * sigmoid(2 * x) - 1.
  // Scale the input in place, reuse the sigmoid reference kernel,
  // then map the result from (0, 1) onto (-1, 1).
  const T two = static_cast<T>(2);
  const T one = static_cast<T>(1);
  for (int i = 0; i < n; ++i) {
    y[i] = two * x[i];
  }
  VSigmoidRefer(y, y, n);
  for (int i = 0; i < n; ++i) {
    y[i] = two * y[i] - one;
  }
}
68
+
56
69
#ifdef PADDLE_WITH_MKLML
57
70
template <typename T>
58
71
void VExpMKL (const T* x, T* y, int n);
@@ -80,6 +93,17 @@ void VSigmoidMKL(const T* x, T* y, int n) {
80
93
y[i] = static_cast <T>(1 ) / (static_cast <T>(1 ) + y[i]);
81
94
}
82
95
}
96
+
97
template <typename T>
void VTanhMKL(const T* x, T* y, int n) {
  // MKL-backed tanh via the identity tanh(x) = 2 * sigmoid(2 * x) - 1:
  // double the input in place, run the MKL sigmoid kernel, then rescale
  // the (0, 1) output onto (-1, 1).
  const T two = static_cast<T>(2);
  for (int i = 0; i < n; ++i) {
    y[i] = two * x[i];
  }
  VSigmoidMKL(y, y, n);
  for (int i = 0; i < n; ++i) {
    y[i] = two * y[i] - static_cast<T>(1);
  }
}
83
107
#endif
84
108
85
109
/* VExp JitKernel */
@@ -189,8 +213,63 @@ bool VSigmoidKernelImpl<double>::useMKL(int d) {
189
213
}
190
214
#endif
191
215
216
/* VTanh JitKernel */
// Dispatching wrapper: at construction time picks the best available tanh
// implementation for size d (xbyak JIT > MKL > scalar reference) and binds
// it to this->Compute.
template <typename T>
class VTanhKernelImpl : public VTanhKernel<T> {
 public:
  JITKERNEL_DECLARE_STATIC_FUNC;
  explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
    this->num_ = d;  // TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
    if (useJIT(d)) {
      // Code-buffer size estimate for the generated kernel; grows with the
      // number of AVX blocks. NOTE(review): marked "should change" upstream —
      // the constants look empirical, verify against VTanhJitCode.
      size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;  // should change
      jitcode_.reset(new gen::VTanhJitCode(d, sz > 4096 ? sz : 4096));
      this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
      return;
    }
#endif

#ifdef PADDLE_WITH_MKLML
    // strictly it's a better impl with MKL, then is refer
    if (useMKL(d)) {
      this->Compute = VTanhMKL<T>;
      return;
    }
#endif
    // Fallback: portable scalar reference implementation.
    this->Compute = VTanhRefer<T>;
  }
  // Legacy entry point; always uses the reference implementation regardless
  // of which Compute backend was selected above.
  void ComputeDeprecated(const T* x, T* y) const override {
    VTanhRefer(x, y, this->num_);
  }
#ifdef PADDLE_WITH_XBYAK

 private:
  // Owns the generated code; must outlive the Compute function pointer
  // handed out in the constructor.
  std::unique_ptr<gen::VTanhJitCode> jitcode_{nullptr};
#endif
};
250
+
251
#ifdef PADDLE_WITH_XBYAK
// JIT generation is only provided for float; delegates the size check to the
// code generator itself.
template <>
bool VTanhKernelImpl<float>::useJIT(int d) {
  return gen::VTanhJitCode::init(d);
}
#endif
257
+
258
#ifdef PADDLE_WITH_MKLML
// For float, MKL only wins at larger sizes; smaller sizes go to JIT/refer.
// NOTE(review): 512 looks like an empirically chosen threshold — confirm
// against benchmarks.
template <>
bool VTanhKernelImpl<float>::useMKL(int d) {
  return d > 512;
}

// For double there is no JIT specialization in this file, so MKL is always
// preferred over the scalar reference.
template <>
bool VTanhKernelImpl<double>::useMKL(int d) {
  return true;
}
#endif
269
+
192
270
// Register the elementwise activation kernels with the jit kernel pool.
REGISTER_JITKERNEL(vexp, VExpKernel);
REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel);
REGISTER_JITKERNEL(vtanh, VTanhKernel);
194
273
195
274
namespace detail {
196
275
@@ -337,156 +416,6 @@ __m256 ExpAVX2(__m256 x) {
337
416
#endif
338
417
339
418
} // namespace detail
340
-
341
- #define INTRI_SIGMOID (tmp, min, max, expisa ) \
342
- tmp = _mm256_max_ps(tmp, min); \
343
- tmp = _mm256_min_ps(tmp, max); \
344
- tmp = _mm256_sub_ps(_mm256_set1_ps(0 .0f ), tmp); \
345
- tmp = expisa(tmp); \
346
- tmp = _mm256_add_ps(_mm256_set1_ps(1 .0f ), tmp); \
347
- tmp = _mm256_div_ps(_mm256_set1_ps(1 .0f ), tmp)
348
- #undef INTRI_VSIGMOID
349
-
350
- /* VTanh JitKernel */
351
- template <typename T, jit::cpu_isa_t isa, jit_block>
352
- class VTanhKernelImpl : public VTanhKernel <T> {
353
- public:
354
- explicit VTanhKernelImpl (int d) : VTanhKernel<T>() {
355
- this ->num_ = d;
356
- vscal_ = KernelPool::Instance ().template Get <VScalKernel<T>>(d);
357
- vsigmoid_ = KernelPool::Instance ().template Get <VSigmoidKernel<T>>(d);
358
- vaddbias_ = KernelPool::Instance ().template Get <VAddBiasKernel<T>>(d);
359
- }
360
- void ComputeDeprecated (const T* x, T* y) const override {
361
- const T a = static_cast <T>(2 ), b = static_cast <T>(-1 );
362
- vscal_->Compute (&a, x, y, this ->num_ );
363
- vsigmoid_->ComputeDeprecated (y, y);
364
- vscal_->Compute (&a, y, y, this ->num_ );
365
- vaddbias_->Compute (&b, y, y, this ->num_ );
366
- }
367
-
368
- private:
369
- std::shared_ptr<const VScalKernel<T>> vscal_;
370
- std::shared_ptr<const VSigmoidKernel<T>> vsigmoid_;
371
- std::shared_ptr<const VAddBiasKernel<T>> vaddbias_;
372
- };
373
-
374
- #define INTRI_VTANH (tmp, expisa ) \
375
- tmp = _mm256_mul_ps(_mm256_set1_ps(-2 .0f ), tmp); \
376
- tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \
377
- tmp = expisa(tmp); \
378
- tmp = _mm256_add_ps(_mm256_set1_ps(1 .0f ), tmp); \
379
- tmp = _mm256_div_ps(_mm256_set1_ps(2 .0f ), tmp); \
380
- tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1 .0f ))
381
-
382
- #define INTRI8_FLOAT (isa, expisa ) \
383
- template <> \
384
- void VTanhKernelImpl<float , isa, kEQ8 >::ComputeDeprecated(const float * x, \
385
- float * y) const { \
386
- __m256 tmp = _mm256_loadu_ps (x); \
387
- INTRI_VTANH (tmp, expisa); \
388
- _mm256_storeu_ps (y, tmp); \
389
- }
390
-
391
- #define INTRI16_FLOAT (isa, expisa ) \
392
- template <> \
393
- void VTanhKernelImpl<float , isa, kEQ16 >::ComputeDeprecated(const float * x, \
394
- float * y) const { \
395
- __m256 tmp0 = _mm256_loadu_ps (x); \
396
- __m256 tmp1 = _mm256_loadu_ps (x + 8 ); \
397
- INTRI_VTANH (tmp0, expisa); \
398
- INTRI_VTANH (tmp1, expisa); \
399
- _mm256_storeu_ps (y, tmp0); \
400
- _mm256_storeu_ps (y + 8 , tmp1); \
401
- }
402
-
403
- #define INTRI_GT8LT16_FLOAT (isa, expisa ) \
404
- template <> \
405
- VTanhKernelImpl<float , isa, kGT8LT16 >::VTanhKernelImpl(int d) \
406
- : VTanhKernel<float >() { \
407
- this ->num_ = d; \
408
- this ->end_ = AVX_FLOAT_BLOCK; \
409
- this ->rest_ = d - this ->end_ ; \
410
- vscal_ = \
411
- KernelPool::Instance ().template Get <VScalKernel<float >>(this ->rest_ ); \
412
- vsigmoid_ = KernelPool::Instance ().template Get <VSigmoidKernel<float >>( \
413
- this ->rest_ ); \
414
- vaddbias_ = KernelPool::Instance ().template Get <VAddBiasKernel<float >>( \
415
- this ->rest_ ); \
416
- } \
417
- template <> \
418
- void VTanhKernelImpl<float , isa, kGT8LT16 >::ComputeDeprecated( \
419
- const float * x, float * y) const { \
420
- __m256 tmp = _mm256_loadu_ps (x); \
421
- INTRI_VTANH (tmp, expisa); \
422
- _mm256_storeu_ps (y, tmp); \
423
- x += AVX_FLOAT_BLOCK; \
424
- y += AVX_FLOAT_BLOCK; \
425
- const float a = 2 .f , b = -1 .f ; \
426
- vscal_->Compute (&a, x, y, this ->num_ ); \
427
- vsigmoid_->ComputeDeprecated (y, y); \
428
- vscal_->Compute (&a, y, y, this ->num_ ); \
429
- vaddbias_->Compute (&b, y, y, this ->num_ ); \
430
- }
431
-
432
- #define INTRI_GT16_FLOAT (isa, expisa ) \
433
- template <> \
434
- VTanhKernelImpl<float , isa, kGT16 >::VTanhKernelImpl(int d) \
435
- : VTanhKernel<float >() { \
436
- this ->num_ = d; \
437
- this ->rest_ = d % AVX_FLOAT_BLOCK; \
438
- this ->end_ = d - this ->rest_ ; \
439
- vscal_ = \
440
- KernelPool::Instance ().template Get <VScalKernel<float >>(this ->rest_ ); \
441
- vsigmoid_ = KernelPool::Instance ().template Get <VSigmoidKernel<float >>( \
442
- this ->rest_ ); \
443
- vaddbias_ = KernelPool::Instance ().template Get <VAddBiasKernel<float >>( \
444
- this ->rest_ ); \
445
- } \
446
- template <> \
447
- void VTanhKernelImpl<float , isa, kGT16 >::ComputeDeprecated(const float * x, \
448
- float * y) const { \
449
- for (int i = 0 ; i < this ->end_ ; i += AVX_FLOAT_BLOCK) { \
450
- __m256 tmp = _mm256_loadu_ps (x + i); \
451
- INTRI_VTANH (tmp, expisa); \
452
- _mm256_storeu_ps (y + i, tmp); \
453
- } \
454
- x += this ->end_ ; \
455
- y += this ->end_ ; \
456
- const float a = 2 .f , b = -1 .f ; \
457
- vscal_->Compute (&a, x, y, this ->num_ ); \
458
- vsigmoid_->ComputeDeprecated (y, y); \
459
- vscal_->Compute (&a, y, y, this ->num_ ); \
460
- vaddbias_->Compute (&b, y, y, this ->num_ ); \
461
- }
462
-
463
- #ifdef __AVX__
464
- INTRI8_FLOAT (jit::avx, detail::ExpAVX);
465
- INTRI16_FLOAT (jit::avx, detail::ExpAVX);
466
- INTRI_GT8LT16_FLOAT (jit::avx, detail::ExpAVX);
467
- INTRI_GT16_FLOAT (jit::avx, detail::ExpAVX);
468
- #endif
469
- #ifdef __AVX2__
470
- INTRI8_FLOAT (jit::avx2, detail::ExpAVX2);
471
- INTRI16_FLOAT (jit::avx2, detail::ExpAVX2);
472
- // maybe use avx at gt8lt16 and gt16
473
- #endif
474
- #ifdef __AVX512F__
475
- INTRI8_FLOAT (jit::avx512f, detail::ExpAVX2);
476
- INTRI16_FLOAT (jit::avx512f, detail::ExpAVX2);
477
- // maybe use avx at gt8lt16 and gt16
478
- #endif
479
-
480
- #undef INTRI8_FLOAT
481
- #undef INTRI16_FLOAT
482
- #undef INTRI_GT8LT16_FLOAT
483
- #undef INTRI_GT16_FLOAT
484
- #undef INTRI_VTANH
485
-
486
- REGISTER_JITKERNEL_DEPRECATED (vtanh, VTanhKernel);
487
-
488
- #undef JITKERNEL_NEW_ACT_IMPL
489
-
490
419
} // namespace jitkernel
491
420
} // namespace math
492
421
} // namespace operators
0 commit comments