@@ -71,6 +71,13 @@ void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
71
71
}
72
72
}
73
73
74
// Scalar reference ReLU: writes y[i] = max(x[i], 0) for each of the n
// elements. Used as the portable fallback when no JIT kernel is generated.
template <typename T>
void VReluRefer(const T* x, T* y, int n) {
  int i = 0;
  while (i < n) {
    const T v = x[i];
    y[i] = v > 0 ? v : static_cast<T>(0);
    ++i;
  }
}
80
+
74
81
#ifdef PADDLE_WITH_MKLML
75
82
template <typename T>
76
83
void VMulMKL (const T* x, const T* y, T* z, int n);
@@ -344,124 +351,60 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
344
351
}
345
352
#endif
346
353
347
- #undef DECLARE_STATIC_FUNC
348
-
349
- REGISTER_JITKERNEL (vmul, VMulKernel);
350
- REGISTER_JITKERNEL (vadd, VAddKernel);
351
- REGISTER_JITKERNEL (vaddrelu, VAddReluKernel);
352
- REGISTER_JITKERNEL (vscal, VScalKernel);
353
- REGISTER_JITKERNEL (vaddbias, VAddBiasKernel);
354
-
355
354
/* VRelu JitKernel */
template <typename T>
class VReluKernelImpl : public VReluKernel<T> {
 public:
  // Declares the static `Compute` function pointer member and the
  // `useJIT(int)` predicate shared by all jit kernel impls in this file.
  DECLARE_STATIC_FUNC;
  explicit VReluKernelImpl(int d) : VReluKernel<T>() {
    this->num_ = d;  // TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
    if (useJIT(d)) {
      // Rough code-buffer size estimate: fixed prologue bytes plus
      // ~4 instructions (at ~8 bytes each) per AVX float block.
      size_t sz = 96 /* init */ +
                  d / AVX_FLOAT_BLOCK * 4 /* instructions */ *
                      8 /* average bytes per instruction */;
      // Never allocate less than 4096 bytes for the generated code.
      jitcode_.reset(new gen::ReluJitCode(d, sz > 4096 ? sz : 4096));
      // Point the static Compute entry at the freshly generated code.
      this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
      return;
    }
#endif
    // JIT disabled or size unsupported: fall back to the scalar reference.
    this->Compute = VReluRefer<T>;
  }
  // Legacy entry point kept for callers not yet migrated to Compute.
  void ComputeDeprecated(const T* x, T* y) const override {
    VReluRefer(x, y, this->num_);
  }
#ifdef PADDLE_WITH_XBYAK

 private:
  std::unique_ptr<gen::ReluJitCode> jitcode_{nullptr};
#endif
};

#ifdef PADDLE_WITH_XBYAK
// The float JIT path is available only when ReluJitCode supports size d.
template <>
bool VReluKernelImpl<float>::useJIT(int d) {
  return gen::ReluJitCode::init(d);
}
#endif
450
390
451
- #undef INTRI8_FLOAT
452
- #undef INTRI16_FLOAT
453
- #undef INTRI_GT8LT16_FLOAT
454
- #undef INTRI_GT16_FLOAT
391
#undef DECLARE_STATIC_FUNC

// Register every kernel that follows the static-Compute (jit-capable)
// pattern under its lookup key.
REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
REGISTER_JITKERNEL(vrelu, VReluKernel);
455
399
456
400
/* An empty JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VIdentityKernelImpl : public VIdentityKernel<T> {
 public:
  explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() { this->num_ = d; }
  // Intentionally a no-op: identity kernel writes nothing to y.
  void ComputeDeprecated(const T* x, T* y) const override {}
};
463
407
464
// videntity still uses the legacy registration path (no static Compute).
REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
466
409
467
410
} // namespace jitkernel
0 commit comments