Skip to content

Commit b68ecec

Browse files
committed
add vaddrelu jitcode
test=develop
1 parent bb09e31 commit b68ecec

File tree

5 files changed

+66
-102
lines changed

5 files changed

+66
-102
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,16 @@ bool VAddJitCode::init(int d) { return MayIUse(avx); }
7070

7171
void VAddJitCode::generate() {
7272
int offset = 0;
73+
if (with_relu_) {
74+
vxorps(ymm_zero, ymm_zero, ymm_zero);
75+
}
7376
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
7477
vmovups(ymm_src1, ptr[param1 + offset]);
7578
vmovups(ymm_src2, ptr[param2 + offset]);
7679
vaddps(ymm_dst, ymm_src1, ymm_src2);
80+
if (with_relu_) {
81+
vmaxps(ymm_dst, ymm_zero, ymm_dst);
82+
}
7783
vmovups(ptr[param3 + offset], ymm_dst);
7884
offset += sizeof(float) * AVX_FLOAT_BLOCK;
7985
}
@@ -82,6 +88,9 @@ void VAddJitCode::generate() {
8288
vmovups(xmm_src1, ptr[param1 + offset]);
8389
vmovups(xmm_src2, ptr[param2 + offset]);
8490
vaddps(xmm_dst, xmm_src1, xmm_src2);
91+
if (with_relu_) {
92+
vmaxps(xmm_dst, xmm_zero, xmm_dst);
93+
}
8594
vmovups(ptr[param3 + offset], xmm_dst);
8695
offset += sizeof(float) * 4;
8796
rest -= 4;
@@ -90,6 +99,9 @@ void VAddJitCode::generate() {
9099
vmovq(xmm_src1, ptr[param1 + offset]);
91100
vmovq(xmm_src2, ptr[param2 + offset]);
92101
vaddps(xmm_dst, xmm_src1, xmm_src2);
102+
if (with_relu_) {
103+
vmaxps(xmm_dst, xmm_zero, xmm_dst);
104+
}
93105
vmovq(ptr[param3 + offset], xmm_dst);
94106
offset += sizeof(float) * 2;
95107
rest -= 2;
@@ -98,6 +110,9 @@ void VAddJitCode::generate() {
98110
vmovss(xmm_src1, ptr[param1 + offset]);
99111
vmovss(xmm_src2, ptr[param2 + offset]);
100112
vaddss(xmm_dst, xmm_src1, xmm_src2);
113+
if (with_relu_) {
114+
vmaxps(xmm_dst, xmm_zero, xmm_dst);
115+
}
101116
vmovss(ptr[param3 + offset], xmm_dst);
102117
}
103118
ret();

paddle/fluid/operators/math/jit_code.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,35 +46,38 @@ class VMulJitCode : public JitCode {
4646

4747
xmm_t xmm_src1 = xmm_t(0);
4848
xmm_t xmm_src2 = xmm_t(1);
49-
xmm_t xmm_dst = xmm_t(2);
49+
xmm_t xmm_dst = xmm_t(1);
5050

5151
ymm_t ymm_src1 = ymm_t(0);
5252
ymm_t ymm_src2 = ymm_t(1);
53-
ymm_t ymm_dst = ymm_t(2);
53+
ymm_t ymm_dst = ymm_t(1);
5454
};
5555

5656
class VAddJitCode : public JitCode {
5757
public:
5858
DECLARE_JIT_CODE(VAddJitCode);
59-
explicit VAddJitCode(int d, size_t code_size = 256 * 1024,
59+
explicit VAddJitCode(int d, bool with_relu, size_t code_size = 256 * 1024,
6060
void* code_ptr = nullptr)
61-
: JitCode(code_size, code_ptr), num_(d) {}
61+
: JitCode(code_size, code_ptr), num_(d), with_relu_(with_relu) {}
6262
static bool init(int d);
6363
void generate() override;
6464

6565
private:
6666
int num_;
67+
bool with_relu_;
6768
reg64_t param1{abi_param1};
6869
reg64_t param2{abi_param2};
6970
reg64_t param3{abi_param3};
7071

7172
xmm_t xmm_src1 = xmm_t(0);
7273
xmm_t xmm_src2 = xmm_t(1);
73-
xmm_t xmm_dst = xmm_t(2);
74+
xmm_t xmm_dst = xmm_t(1);
75+
xmm_t xmm_zero = xmm_t(2);
7476

7577
ymm_t ymm_src1 = ymm_t(0);
7678
ymm_t ymm_src2 = ymm_t(1);
77-
ymm_t ymm_dst = ymm_t(2);
79+
ymm_t ymm_dst = ymm_t(1);
80+
ymm_t ymm_zero = ymm_t(2);
7881
};
7982

8083
} // namespace gen

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,22 +75,22 @@ class VAddKernel : public Kernel {
7575
};
7676

7777
template <typename T>
78-
class VScalKernel : public Kernel {
78+
class VAddReluKernel : public Kernel {
7979
public:
80-
virtual void Compute(const T a, const T *x, T *y) const = 0;
81-
virtual void Compute(const T a, T *x) const = 0;
80+
void (*Compute)(const T *, const T *, T *, int);
8281
};
8382

8483
template <typename T>
85-
class VAddBiasKernel : public Kernel {
84+
class VScalKernel : public Kernel {
8685
public:
8786
virtual void Compute(const T a, const T *x, T *y) const = 0;
87+
virtual void Compute(const T a, T *x) const = 0;
8888
};
8989

9090
template <typename T>
91-
class VAddReluKernel : public Kernel {
91+
class VAddBiasKernel : public Kernel {
9292
public:
93-
virtual void Compute(const T *x, const T *y, T *z) const = 0;
93+
virtual void Compute(const T a, const T *x, T *y) const = 0;
9494
};
9595

9696
template <typename T>

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 35 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ void VAddRefer(const T* x, const T* y, T* z, int n) {
4646
}
4747
}
4848

49+
template <typename T>
50+
void VAddReluRefer(const T* x, const T* y, T* z, int n) {
51+
for (int i = 0; i < n; ++i) {
52+
z[i] = x[i] + y[i];
53+
z[i] = z[i] > 0 ? z[i] : 0;
54+
}
55+
}
56+
4957
#ifdef PADDLE_WITH_MKLML
5058
template <typename T>
5159
void VMulMKL(const T* x, const T* y, T* z, int n);
@@ -131,7 +139,7 @@ class VAddKernelImpl : public VAddKernel<T> {
131139
explicit VAddKernelImpl(int d) : VAddKernel<T>() {
132140
if (useJIT(d)) {
133141
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
134-
jitcode_.reset(new gen::VAddJitCode(d, sz > 4096 ? sz : 4096));
142+
jitcode_.reset(new gen::VAddJitCode(d, false, sz > 4096 ? sz : 4096));
135143
this->Compute =
136144
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
137145
return;
@@ -164,10 +172,36 @@ bool VAddKernelImpl<double>::useMKL(int d) {
164172
return true;
165173
}
166174

175+
/* VAddRelu JitKernel */
176+
template <typename T>
177+
class VAddReluKernelImpl : public VAddReluKernel<T> {
178+
public:
179+
DECLARE_STATIC_FUNC;
180+
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() {
181+
if (useJIT(d)) {
182+
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
183+
jitcode_.reset(new gen::VAddJitCode(d, true, sz > 4096 ? sz : 4096));
184+
this->Compute =
185+
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
186+
return;
187+
}
188+
this->Compute = VAddReluRefer<T>;
189+
}
190+
191+
private:
192+
std::unique_ptr<gen::VAddJitCode> jitcode_{nullptr};
193+
};
194+
195+
template <>
196+
bool VAddReluKernelImpl<float>::useJIT(int d) {
197+
return gen::VAddJitCode::init(d);
198+
}
199+
167200
#undef DECLARE_STATIC_FUNC
168201

169202
REGISTER_JITKERNEL(vmul, VMulKernel);
170203
REGISTER_JITKERNEL(vadd, VAddKernel);
204+
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
171205

172206
/* VSCAL JitKernel */
173207
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
@@ -404,97 +438,9 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
404438
void Compute(const T* x, T* y) const override {}
405439
};
406440

407-
/* VAddRelu JitKernel */
408-
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
409-
class VAddReluKernelImpl : public VAddReluKernel<T> {
410-
public:
411-
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() { this->num_ = d; }
412-
void Compute(const T* x, const T* y, T* z) const override {
413-
for (int i = 0; i < this->num_; ++i) {
414-
z[i] = x[i] + y[i];
415-
z[i] = z[i] > 0 ? z[i] : 0;
416-
}
417-
}
418-
};
419-
420-
#define INTRI8_FLOAT(isa) \
421-
template <> \
422-
void VAddReluKernelImpl<float, isa, kEQ8>::Compute( \
423-
const float* x, const float* y, float* z) const { \
424-
__m256 tmpx = _mm256_loadu_ps(x); \
425-
__m256 tmpy = _mm256_loadu_ps(y); \
426-
tmpy = _mm256_add_ps(tmpx, tmpy); \
427-
tmpy = _mm256_max_ps(tmpy, _mm256_setzero_ps()); \
428-
_mm256_storeu_ps(z, tmpy); \
429-
}
430-
431-
#define INTRI16_FLOAT(isa) \
432-
template <> \
433-
void VAddReluKernelImpl<float, isa, kEQ16>::Compute( \
434-
const float* x, const float* y, float* z) const { \
435-
__m256 zeros = _mm256_setzero_ps(); \
436-
__m256 tmp0 = _mm256_loadu_ps(x); \
437-
__m256 tmp1 = _mm256_loadu_ps(y); \
438-
tmp0 = _mm256_add_ps(tmp0, tmp1); \
439-
tmp0 = _mm256_max_ps(tmp0, zeros); \
440-
tmp1 = _mm256_loadu_ps(x + 8); \
441-
__m256 tmp2 = _mm256_loadu_ps(y + 8); \
442-
tmp1 = _mm256_add_ps(tmp1, tmp2); \
443-
tmp1 = _mm256_max_ps(tmp1, zeros); \
444-
_mm256_storeu_ps(z, tmp0); \
445-
_mm256_storeu_ps(z + 8, tmp1); \
446-
}
447-
448-
#define INTRI_COMMON_FLOAT(isa, block) \
449-
template <> \
450-
VAddReluKernelImpl<float, isa, block>::VAddReluKernelImpl(int d) \
451-
: VAddReluKernel<float>() { \
452-
this->num_ = d; \
453-
this->end_ = d - d % AVX_FLOAT_BLOCK; \
454-
this->rest_ = d - this->end_; \
455-
} \
456-
template <> \
457-
void VAddReluKernelImpl<float, isa, block>::Compute( \
458-
const float* x, const float* y, float* z) const { \
459-
__m256 zeros = _mm256_setzero_ps(); \
460-
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
461-
__m256 tmpx = _mm256_loadu_ps(x + i); \
462-
__m256 tmpy = _mm256_loadu_ps(y + i); \
463-
tmpy = _mm256_add_ps(tmpx, tmpy); \
464-
tmpy = _mm256_max_ps(tmpy, zeros); \
465-
_mm256_storeu_ps(z + i, tmpy); \
466-
} \
467-
for (int i = this->end_; i < this->num_; ++i) { \
468-
z[i] = x[i] + y[i]; \
469-
z[i] = z[i] > 0 ? z[i] : 0; \
470-
} \
471-
}
472-
473-
#ifdef __AVX__
474-
INTRI8_FLOAT(jit::avx);
475-
INTRI16_FLOAT(jit::avx);
476-
INTRI_COMMON_FLOAT(jit::avx, kGT16);
477-
#endif
478-
#ifdef __AVX2__
479-
INTRI8_FLOAT(jit::avx2);
480-
INTRI16_FLOAT(jit::avx2);
481-
INTRI_COMMON_FLOAT(jit::avx2, kGT16);
482-
#endif
483-
#ifdef __AVX512F__
484-
// TODO(TJ): refine avx512
485-
INTRI8_FLOAT(jit::avx512f);
486-
INTRI16_FLOAT(jit::avx512f);
487-
INTRI_COMMON_FLOAT(jit::avx512f, kGT16);
488-
#endif
489-
490-
#undef INTRI8_FLOAT
491-
#undef INTRI16_FLOAT
492-
#undef INTRI_COMMON_FLOAT
493-
494441
REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel);
495442
REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel);
496443
REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
497-
REGISTER_JITKERNEL_DEPRECATED(vaddrelu, VAddReluKernel);
498444
REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
499445

500446
} // namespace jitkernel

paddle/fluid/operators/math/jit_kernel_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ TEST(JitKernel, vaddrelu) {
757757
auto tmkle = GetCurrentUS();
758758
auto ttgts = GetCurrentUS();
759759
for (int i = 0; i < repeat; ++i) {
760-
ker->Compute(x_data, y_data, ztgt_data);
760+
ker->Compute(x_data, y_data, ztgt_data, d);
761761
}
762762
auto ttgte = GetCurrentUS();
763763
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat

0 commit comments

Comments (0)