Commit 0043c42

add vrelu jitcode
test=develop
1 parent 9a6e239 commit 0043c42

7 files changed: 245 additions, 245 deletions
paddle/fluid/operators/math/jit_code.cc

Lines changed: 33 additions & 0 deletions
@@ -118,6 +118,39 @@ void VXXJitCode::generate() {
   ret();
 }
 
+bool ReluJitCode::init(int d) { return MayIUse(avx); }
+
+void ReluJitCode::generate() {
+  int offset = 0;
+  vxorps(ymm_zero, ymm_zero, ymm_zero);
+  for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
+    vmovups(ymm_src, ptr[param1 + offset]);
+    vmaxps(ymm_dst, ymm_zero, ymm_src);
+    vmovups(ptr[param2 + offset], ymm_dst);
+    offset += sizeof(float) * AVX_FLOAT_BLOCK;
+  }
+  int rest = num_ % AVX_FLOAT_BLOCK;
+  if (rest >= 4) {
+    vmovups(xmm_src, ptr[param1 + offset]);
+    vmaxps(xmm_dst, xmm_zero, xmm_src);
+    vmovups(ptr[param2 + offset], xmm_dst);
+    offset += sizeof(float) * 4;
+    rest -= 4;
+  }
+  if (rest >= 2) {
+    vmovups(xmm_src, ptr[param1 + offset]);
+    vmaxps(xmm_dst, xmm_zero, xmm_src);
+    vmovq(ptr[param2 + offset], xmm_dst);
+    offset += sizeof(float) * 2;
+    rest -= 2;
+  }
+  if (rest > 0) {
+    vmovups(xmm_src, ptr[param1 + offset]);
+    vmaxps(xmm_dst, xmm_zero, xmm_src);
+    vmovss(ptr[param2 + offset], xmm_dst);
+  }
+  ret();
+}
 }  // namespace gen
 }  // namespace jitkernel
 }  // namespace math
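For orientation, the control flow that ReluJitCode::generate() emits maps onto the following plain-C++ sketch (not part of the commit; it assumes AVX_FLOAT_BLOCK == 8, i.e. one ymm register of floats). Full 8-float blocks go through ymm registers; the remainder is handled in 4-, 2-, and 1-float steps through xmm registers, where the loads stay full-width (vmovups) and only the stores are narrowed (vmovups, vmovq, vmovss):

  #include <algorithm>

  // Scalar model of the generated kernel. n plays the role of num_,
  // which is fixed when generate() runs.
  void relu_like_jitcode(const float* x, float* y, int n) {
    int i = 0;
    for (; i + 8 <= n; i += 8) {  // ymm path: vmovups / vmaxps / vmovups
      for (int j = 0; j < 8; ++j) y[i + j] = std::max(x[i + j], 0.0f);
    }
    int rest = n - i;
    if (rest >= 4) {  // xmm path, 4 floats stored with vmovups
      for (int j = 0; j < 4; ++j) y[i + j] = std::max(x[i + j], 0.0f);
      i += 4;
      rest -= 4;
    }
    if (rest >= 2) {  // xmm path, 2 floats stored with vmovq
      for (int j = 0; j < 2; ++j) y[i + j] = std::max(x[i + j], 0.0f);
      i += 2;
      rest -= 2;
    }
    if (rest > 0) {  // xmm path, 1 float stored with vmovss
      y[i] = std::max(x[i], 0.0f);
    }
  }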

paddle/fluid/operators/math/jit_code.h

Lines changed: 23 additions & 0 deletions
@@ -85,6 +85,29 @@ class VXXJitCode : public JitCode {
   ymm_t ymm_zero = ymm_t(3);
 };
 
+class ReluJitCode : public JitCode {
+ public:
+  DECLARE_JIT_CODE(ReluJitCode);
+  explicit ReluJitCode(int d, size_t code_size = 256 * 1024,
+                       void* code_ptr = nullptr)
+      : JitCode(code_size, code_ptr), num_(d) {}
+  static bool init(int d);
+  void generate() override;
+
+ private:
+  int num_;
+  reg64_t param1{abi_param1};
+  reg64_t param2{abi_param2};
+
+  xmm_t xmm_zero = xmm_t(0);
+  xmm_t xmm_src = xmm_t(1);
+  xmm_t xmm_dst = xmm_t(1);
+
+  ymm_t ymm_zero = ymm_t(0);
+  ymm_t ymm_src = ymm_t(1);
+  ymm_t ymm_dst = ymm_t(1);
+};
+
 }  // namespace gen
 }  // namespace jitkernel
 }  // namespace math
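Two details worth noting in this header. The src/dst register pairs deliberately alias (xmm_src and xmm_dst are both xmm_t(1), likewise the ymm pair), since vmaxps can write its result over its source. And num_ is baked into the generated code, so the trailing int in the function-pointer type exists only to match VReluKernel<T>::Compute. A minimal usage sketch, mirroring how VReluKernelImpl consumes this class in jit_kernel_blas.cc below (assumes xbyak is enabled and MayIUse(avx) holds; x and y are hypothetical float buffers of at least d elements):

  const int d = 100;
  gen::ReluJitCode jitcode(d);  // code_size defaults to 256 * 1024
  auto relu = jitcode.getCode<void (*)(const float*, float*, int)>();
  relu(x, y, d);  // y[i] = max(x[i], 0.f); the length argument is ignored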

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 7 additions & 6 deletions
@@ -97,37 +97,38 @@ class VAddBiasKernel : public Kernel {
 template <typename T>
 class VActKernel : public Kernel {
  public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
 };
 
 template <typename T>
 class VReluKernel : public VActKernel<T> {
  public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
+  void (*Compute)(const T *, T *, int);
 };
 
 template <typename T>
 class VIdentityKernel : public VActKernel<T> {
  public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
 };
 
 template <typename T>
 class VExpKernel : public VActKernel<T> {
  public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
 };
 
 template <typename T>
 class VSigmoidKernel : public VActKernel<T> {
  public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
 };
 
 template <typename T>
 class VTanhKernel : public VActKernel<T> {
  public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
 };
 
 template <typename T>
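The effect of this interface change is easiest to see at a call site (sketch; vrelu stands for any VReluKernel<float> instance obtained from the kernel pool). The old virtual Compute(x, y), which read the length from the kernel's stored num_, is renamed ComputeDeprecated; the new Compute is a non-virtual function-pointer member that takes the length explicitly, so calls bypass the vtable and can jump straight into jitted code:

  // Before this commit:
  vrelu->Compute(x, y);            // virtual call; length comes from this->num_
  // After this commit:
  vrelu->Compute(x, y, n);         // direct call through a function pointer
  vrelu->ComputeDeprecated(x, y);  // old path, kept during the migration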

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 42 additions & 99 deletions
@@ -71,6 +71,13 @@ void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
   }
 }
 
+template <typename T>
+void VReluRefer(const T* x, T* y, int n) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = x[i] > 0 ? x[i] : 0;
+  }
+}
+
 #ifdef PADDLE_WITH_MKLML
 template <typename T>
 void VMulMKL(const T* x, const T* y, T* z, int n);
@@ -344,124 +351,60 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
 }
 #endif
 
-#undef DECLARE_STATIC_FUNC
-
-REGISTER_JITKERNEL(vmul, VMulKernel);
-REGISTER_JITKERNEL(vadd, VAddKernel);
-REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
-REGISTER_JITKERNEL(vscal, VScalKernel);
-REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
-
 /* VRelu JitKernel */
-template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+template <typename T>
 class VReluKernelImpl : public VReluKernel<T> {
  public:
-  explicit VReluKernelImpl(int d) : VReluKernel<T>() { this->num_ = d; }
-  void Compute(const T* x, T* y) const override {
-    for (int i = 0; i < this->num_; ++i) {
-      y[i] = x[i] > 0 ? x[i] : 0;
+  DECLARE_STATIC_FUNC;
+  explicit VReluKernelImpl(int d) : VReluKernel<T>() {
+    this->num_ = d;  // TODO(TJ): remove me when ComputeDeprecated done
+#ifdef PADDLE_WITH_XBYAK
+    if (useJIT(d)) {
+      size_t sz = 96 /*init*/ +
+                  d / AVX_FLOAT_BLOCK * 4 /* instructions*/ *
+                      8 /*everage byte for each instruction*/;
+      jitcode_.reset(new gen::ReluJitCode(d, sz > 4096 ? sz : 4096));
+      this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
+      return;
     }
-  }
-};
-
-#define INTRI8_FLOAT(isa)                                                   \
-  template <>                                                               \
-  void VReluKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
-      const {                                                               \
-    __m256 tmp = _mm256_loadu_ps(x);                                        \
-    tmp = _mm256_max_ps(tmp, _mm256_setzero_ps());                          \
-    _mm256_storeu_ps(y, tmp);                                               \
-  }
-
-#define INTRI16_FLOAT(isa)                                                   \
-  template <>                                                                \
-  void VReluKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
-      const {                                                                \
-    __m256 zeros = _mm256_setzero_ps();                                      \
-    __m256 tmp0 = _mm256_loadu_ps(x);                                        \
-    __m256 tmp1 = _mm256_loadu_ps(x + 8);                                    \
-    tmp0 = _mm256_max_ps(tmp0, zeros);                                       \
-    tmp1 = _mm256_max_ps(tmp1, zeros);                                       \
-    _mm256_storeu_ps(y, tmp0);                                               \
-    _mm256_storeu_ps(y + 8, tmp1);                                           \
-  }
+#endif
 
-#define INTRI_GT8LT16_FLOAT(isa)                                        \
-  template <>                                                           \
-  VReluKernelImpl<float, isa, kGT8LT16>::VReluKernelImpl(int d)         \
-      : VReluKernel<float>() {                                          \
-    this->num_ = d;                                                     \
-    this->end_ = AVX_FLOAT_BLOCK;                                       \
-    this->rest_ = d - AVX_FLOAT_BLOCK;                                  \
-  }                                                                     \
-  template <>                                                           \
-  void VReluKernelImpl<float, isa, kGT8LT16>::Compute(const float* x,   \
-                                                      float* y) const { \
-    __m256 zeros = _mm256_setzero_ps();                                 \
-    __m256 tmp0 = _mm256_loadu_ps(x);                                   \
-    __m256 tmp1 = _mm256_loadu_ps(x + this->rest_);                     \
-    tmp0 = _mm256_max_ps(tmp0, zeros);                                  \
-    tmp1 = _mm256_max_ps(tmp1, zeros);                                  \
-    _mm256_storeu_ps(y, tmp0);                                          \
-    _mm256_storeu_ps(y + this->rest_, tmp1);                            \
+    this->Compute = VReluRefer<T>;
   }
-
-#define INTRI_GT16_FLOAT(isa)                                                \
-  template <>                                                                \
-  VReluKernelImpl<float, isa, kGT16>::VReluKernelImpl(int d)                 \
-      : VReluKernel<float>() {                                               \
-    this->num_ = d;                                                          \
-    this->end_ = d - d % AVX_FLOAT_BLOCK;                                    \
-    this->rest_ = d - AVX_FLOAT_BLOCK;                                       \
-  }                                                                          \
-  template <>                                                                \
-  void VReluKernelImpl<float, isa, kGT16>::Compute(const float* x, float* y) \
-      const {                                                                \
-    __m256 zeros = _mm256_setzero_ps();                                      \
-    for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) {                  \
-      __m256 tmp = _mm256_loadu_ps(x + i);                                   \
-      tmp = _mm256_max_ps(tmp, zeros);                                       \
-      _mm256_storeu_ps(y + i, tmp);                                          \
-    }                                                                        \
-    __m256 tmp = _mm256_loadu_ps(x + this->rest_);                           \
-    tmp = _mm256_max_ps(tmp, zeros);                                         \
-    _mm256_storeu_ps(y + this->rest_, tmp);                                  \
+  void ComputeDeprecated(const T* x, T* y) const override {
+    VReluRefer(x, y, this->num_);
   }
+#ifdef PADDLE_WITH_XBYAK
 
-#ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
-INTRI16_FLOAT(jit::avx);
-INTRI_GT8LT16_FLOAT(jit::avx);
-INTRI_GT16_FLOAT(jit::avx);
-#endif
-#ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
-INTRI16_FLOAT(jit::avx2);
-INTRI_GT8LT16_FLOAT(jit::avx2);
-INTRI_GT16_FLOAT(jit::avx2);
+ private:
+  std::unique_ptr<gen::ReluJitCode> jitcode_{nullptr};
 #endif
-#ifdef __AVX512F__
-// TODO(TJ): refine avx512
-INTRI8_FLOAT(jit::avx512f);
-INTRI16_FLOAT(jit::avx512f);
-INTRI_GT8LT16_FLOAT(jit::avx512f);
-INTRI_GT16_FLOAT(jit::avx512f);
+};
+
+#ifdef PADDLE_WITH_XBYAK
+template <>
+bool VReluKernelImpl<float>::useJIT(int d) {
+  return gen::ReluJitCode::init(d);
+}
 #endif
 
-#undef INTRI8_FLOAT
-#undef INTRI16_FLOAT
-#undef INTRI_GT8LT16_FLOAT
-#undef INTRI_GT16_FLOAT
+#undef DECLARE_STATIC_FUNC
+
+REGISTER_JITKERNEL(vmul, VMulKernel);
+REGISTER_JITKERNEL(vadd, VAddKernel);
+REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
+REGISTER_JITKERNEL(vscal, VScalKernel);
+REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
+REGISTER_JITKERNEL(vrelu, VReluKernel);
 
 /* An empty JitKernel */
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 class VIdentityKernelImpl : public VIdentityKernel<T> {
  public:
   explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() { this->num_ = d; }
-  void Compute(const T* x, T* y) const override {}
+  void ComputeDeprecated(const T* x, T* y) const override {}
 };
 
-REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
 REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
 
 }  // namespace jitkernel
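Two notes on the jit_kernel_blas.cc change. The code-size estimate is a coarse upper bound of roughly 96 + 4 * d bytes: for d = 256 it gives 96 + 256 / 8 * 4 * 8 = 1120, which is then raised to the 4096-byte floor, and the estimate only exceeds that floor for d > 1000. End-to-end, the newly registered kernel would be used as in the sketch below (the Get<> spelling follows the other jit_kernel call sites and tests; treat it as an assumption):

  // Fetch the vrelu kernel for length n and run it on float buffers x, y.
  const auto& vrelu =
      jitkernel::KernelPool::Instance().Get<jitkernel::VReluKernel<float>>(n);
  vrelu->Compute(x, y, n);  // jitted when xbyak is enabled and useJIT(n)
                            // holds; otherwise falls back to VReluRefer<float>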
