@@ -46,6 +46,14 @@ void VAddRefer(const T* x, const T* y, T* z, int n) {
46
46
}
47
47
}
48
48
49
+ template <typename T>
50
+ void VAddReluRefer (const T* x, const T* y, T* z, int n) {
51
+ for (int i = 0 ; i < n; ++i) {
52
+ z[i] = x[i] + y[i];
53
+ z[i] = z[i] > 0 ? z[i] : 0 ;
54
+ }
55
+ }
56
+
49
57
#ifdef PADDLE_WITH_MKLML
50
58
template <typename T>
51
59
void VMulMKL (const T* x, const T* y, T* z, int n);
@@ -131,7 +139,7 @@ class VAddKernelImpl : public VAddKernel<T> {
131
139
explicit VAddKernelImpl (int d) : VAddKernel<T>() {
132
140
if (useJIT (d)) {
133
141
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8 ;
134
- jitcode_.reset (new gen::VAddJitCode (d, sz > 4096 ? sz : 4096 ));
142
+ jitcode_.reset (new gen::VAddJitCode (d, false , sz > 4096 ? sz : 4096 ));
135
143
this ->Compute =
136
144
jitcode_->getCode <void (*)(const T*, const T*, T*, int )>();
137
145
return ;
@@ -164,10 +172,36 @@ bool VAddKernelImpl<double>::useMKL(int d) {
164
172
return true ;
165
173
}
166
174
175
+ /* VAddRelu JitKernel */
176
+ template <typename T>
177
+ class VAddReluKernelImpl : public VAddReluKernel <T> {
178
+ public:
179
+ DECLARE_STATIC_FUNC;
180
+ explicit VAddReluKernelImpl (int d) : VAddReluKernel<T>() {
181
+ if (useJIT (d)) {
182
+ size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8 ;
183
+ jitcode_.reset (new gen::VAddJitCode (d, true , sz > 4096 ? sz : 4096 ));
184
+ this ->Compute =
185
+ jitcode_->getCode <void (*)(const T*, const T*, T*, int )>();
186
+ return ;
187
+ }
188
+ this ->Compute = VAddReluRefer<T>;
189
+ }
190
+
191
+ private:
192
+ std::unique_ptr<gen::VAddJitCode> jitcode_{nullptr };
193
+ };
194
+
195
+ template <>
196
+ bool VAddReluKernelImpl<float >::useJIT(int d) {
197
+ return gen::VAddJitCode::init (d);
198
+ }
199
+
167
200
#undef DECLARE_STATIC_FUNC
168
201
169
202
REGISTER_JITKERNEL (vmul, VMulKernel);
170
203
REGISTER_JITKERNEL (vadd, VAddKernel);
204
+ REGISTER_JITKERNEL (vaddrelu, VAddReluKernel);
171
205
172
206
/* VSCAL JitKernel */
173
207
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
@@ -404,97 +438,9 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
404
438
void Compute (const T* x, T* y) const override {}
405
439
};
406
440
407
- /* VAddRelu JitKernel */
408
- template <typename T, platform::jit::cpu_isa_t isa, jit_block>
409
- class VAddReluKernelImpl : public VAddReluKernel <T> {
410
- public:
411
- explicit VAddReluKernelImpl (int d) : VAddReluKernel<T>() { this ->num_ = d; }
412
- void Compute (const T* x, const T* y, T* z) const override {
413
- for (int i = 0 ; i < this ->num_ ; ++i) {
414
- z[i] = x[i] + y[i];
415
- z[i] = z[i] > 0 ? z[i] : 0 ;
416
- }
417
- }
418
- };
419
-
420
- #define INTRI8_FLOAT (isa ) \
421
- template <> \
422
- void VAddReluKernelImpl<float , isa, kEQ8 >::Compute( \
423
- const float * x, const float * y, float * z) const { \
424
- __m256 tmpx = _mm256_loadu_ps (x); \
425
- __m256 tmpy = _mm256_loadu_ps (y); \
426
- tmpy = _mm256_add_ps (tmpx, tmpy); \
427
- tmpy = _mm256_max_ps (tmpy, _mm256_setzero_ps ()); \
428
- _mm256_storeu_ps (z, tmpy); \
429
- }
430
-
431
- #define INTRI16_FLOAT (isa ) \
432
- template <> \
433
- void VAddReluKernelImpl<float , isa, kEQ16 >::Compute( \
434
- const float * x, const float * y, float * z) const { \
435
- __m256 zeros = _mm256_setzero_ps (); \
436
- __m256 tmp0 = _mm256_loadu_ps (x); \
437
- __m256 tmp1 = _mm256_loadu_ps (y); \
438
- tmp0 = _mm256_add_ps (tmp0, tmp1); \
439
- tmp0 = _mm256_max_ps (tmp0, zeros); \
440
- tmp1 = _mm256_loadu_ps (x + 8 ); \
441
- __m256 tmp2 = _mm256_loadu_ps (y + 8 ); \
442
- tmp1 = _mm256_add_ps (tmp1, tmp2); \
443
- tmp1 = _mm256_max_ps (tmp1, zeros); \
444
- _mm256_storeu_ps (z, tmp0); \
445
- _mm256_storeu_ps (z + 8 , tmp1); \
446
- }
447
-
448
- #define INTRI_COMMON_FLOAT (isa, block ) \
449
- template <> \
450
- VAddReluKernelImpl<float , isa, block>::VAddReluKernelImpl(int d) \
451
- : VAddReluKernel<float >() { \
452
- this ->num_ = d; \
453
- this ->end_ = d - d % AVX_FLOAT_BLOCK; \
454
- this ->rest_ = d - this ->end_ ; \
455
- } \
456
- template <> \
457
- void VAddReluKernelImpl<float , isa, block>::Compute( \
458
- const float * x, const float * y, float * z) const { \
459
- __m256 zeros = _mm256_setzero_ps (); \
460
- for (int i = 0 ; i < this ->end_ ; i += AVX_FLOAT_BLOCK) { \
461
- __m256 tmpx = _mm256_loadu_ps (x + i); \
462
- __m256 tmpy = _mm256_loadu_ps (y + i); \
463
- tmpy = _mm256_add_ps (tmpx, tmpy); \
464
- tmpy = _mm256_max_ps (tmpy, zeros); \
465
- _mm256_storeu_ps (z + i, tmpy); \
466
- } \
467
- for (int i = this ->end_ ; i < this ->num_ ; ++i) { \
468
- z[i] = x[i] + y[i]; \
469
- z[i] = z[i] > 0 ? z[i] : 0 ; \
470
- } \
471
- }
472
-
473
- #ifdef __AVX__
474
- INTRI8_FLOAT (jit::avx);
475
- INTRI16_FLOAT (jit::avx);
476
- INTRI_COMMON_FLOAT (jit::avx, kGT16 );
477
- #endif
478
- #ifdef __AVX2__
479
- INTRI8_FLOAT (jit::avx2);
480
- INTRI16_FLOAT (jit::avx2);
481
- INTRI_COMMON_FLOAT (jit::avx2, kGT16 );
482
- #endif
483
- #ifdef __AVX512F__
484
- // TODO(TJ): refine avx512
485
- INTRI8_FLOAT (jit::avx512f);
486
- INTRI16_FLOAT (jit::avx512f);
487
- INTRI_COMMON_FLOAT (jit::avx512f, kGT16 );
488
- #endif
489
-
490
- #undef INTRI8_FLOAT
491
- #undef INTRI16_FLOAT
492
- #undef INTRI_COMMON_FLOAT
493
-
494
441
REGISTER_JITKERNEL_DEPRECATED (vscal, VScalKernel);
495
442
REGISTER_JITKERNEL_DEPRECATED (vaddb, VAddBiasKernel);
496
443
REGISTER_JITKERNEL_DEPRECATED (vrelu, VReluKernel);
497
- REGISTER_JITKERNEL_DEPRECATED (vaddrelu, VAddReluKernel);
498
444
REGISTER_JITKERNEL_DEPRECATED (videntity, VIdentityKernel);
499
445
500
446
} // namespace jitkernel
0 commit comments