
Commit bb09e31

add vadd jitcode
test=develop
1 parent d55481c commit bb09e31

File tree

6 files changed: +135 −65 lines

paddle/fluid/operators/math/jit_code.cc

Lines changed: 36 additions & 0 deletions
@@ -66,6 +66,42 @@ void VMulJitCode::generate() {
   ret();
 }
 
+bool VAddJitCode::init(int d) { return MayIUse(avx); }
+
+void VAddJitCode::generate() {
+  int offset = 0;
+  for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
+    vmovups(ymm_src1, ptr[param1 + offset]);
+    vmovups(ymm_src2, ptr[param2 + offset]);
+    vaddps(ymm_dst, ymm_src1, ymm_src2);
+    vmovups(ptr[param3 + offset], ymm_dst);
+    offset += sizeof(float) * AVX_FLOAT_BLOCK;
+  }
+  int rest = num_ % AVX_FLOAT_BLOCK;
+  if (rest >= 4) {
+    vmovups(xmm_src1, ptr[param1 + offset]);
+    vmovups(xmm_src2, ptr[param2 + offset]);
+    vaddps(xmm_dst, xmm_src1, xmm_src2);
+    vmovups(ptr[param3 + offset], xmm_dst);
+    offset += sizeof(float) * 4;
+    rest -= 4;
+  }
+  if (rest >= 2) {
+    vmovq(xmm_src1, ptr[param1 + offset]);
+    vmovq(xmm_src2, ptr[param2 + offset]);
+    vaddps(xmm_dst, xmm_src1, xmm_src2);
+    vmovq(ptr[param3 + offset], xmm_dst);
+    offset += sizeof(float) * 2;
+    rest -= 2;
+  }
+  if (rest > 0) {
+    vmovss(xmm_src1, ptr[param1 + offset]);
+    vmovss(xmm_src2, ptr[param2 + offset]);
+    vaddss(xmm_dst, xmm_src1, xmm_src2);
+    vmovss(ptr[param3 + offset], xmm_dst);
+  }
+  ret();
+}
 }  // namespace gen
 }  // namespace jitkernel
 }  // namespace math
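
The emitted kernel streams the inputs in 8-float ymm blocks and then drains the remainder with 4-, 2- and 1-float moves. As a rough, hand-written illustration only (not the emitted code; it uses element offsets instead of byte offsets and folds the 2- and 1-element tails into a scalar loop), the same tiling looks like this in plain C++ with AVX intrinsics, assuming AVX_FLOAT_BLOCK == 8:

#include <immintrin.h>

// Illustrative reference of the tiling the JIT code above generates.
void vadd_like_jit(const float* x, const float* y, float* z, int n) {
  int offset = 0;
  for (int i = 0; i < n / 8; ++i) {  // full 8-float ymm blocks
    __m256 a = _mm256_loadu_ps(x + offset);
    __m256 b = _mm256_loadu_ps(y + offset);
    _mm256_storeu_ps(z + offset, _mm256_add_ps(a, b));
    offset += 8;
  }
  int rest = n % 8;
  if (rest >= 4) {  // one 4-float xmm block
    __m128 a = _mm_loadu_ps(x + offset);
    __m128 b = _mm_loadu_ps(y + offset);
    _mm_storeu_ps(z + offset, _mm_add_ps(a, b));
    offset += 4;
    rest -= 4;
  }
  while (rest-- > 0) {  // 2- and 1-element tails, scalar here for brevity
    z[offset] = x[offset] + y[offset];
    ++offset;
  }
}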

paddle/fluid/operators/math/jit_code.h

Lines changed: 24 additions & 0 deletions
@@ -53,6 +53,30 @@ class VMulJitCode : public JitCode {
   ymm_t ymm_dst = ymm_t(2);
 };
 
+class VAddJitCode : public JitCode {
+ public:
+  DECLARE_JIT_CODE(VAddJitCode);
+  explicit VAddJitCode(int d, size_t code_size = 256 * 1024,
+                       void* code_ptr = nullptr)
+      : JitCode(code_size, code_ptr), num_(d) {}
+  static bool init(int d);
+  void generate() override;
+
+ private:
+  int num_;
+  reg64_t param1{abi_param1};
+  reg64_t param2{abi_param2};
+  reg64_t param3{abi_param3};
+
+  xmm_t xmm_src1 = xmm_t(0);
+  xmm_t xmm_src2 = xmm_t(1);
+  xmm_t xmm_dst = xmm_t(2);
+
+  ymm_t ymm_src1 = ymm_t(0);
+  ymm_t ymm_src2 = ymm_t(1);
+  ymm_t ymm_dst = ymm_t(2);
+};
+
 }  // namespace gen
 }  // namespace jitkernel
 }  // namespace math
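
For context, a hypothetical stand-alone use of the new generator would mirror the pattern VAddKernelImpl adopts below: check init(), construct the code object for a given length, and fetch the emitted function through getCode(). Names and the namespace path are taken from this commit; the helper itself is not part of it.

namespace gen = paddle::operators::math::jitkernel::gen;

// Hypothetical snippet: emit and call a vadd kernel for d floats.
// init() reports whether AVX is usable on this machine.
void run_vadd_jit(const float* x, const float* y, float* z, int d) {
  if (!gen::VAddJitCode::init(d)) return;  // caller falls back elsewhere
  gen::VAddJitCode code(d);                // default 256 KiB code buffer
  auto fn = code.getCode<void (*)(const float*, const float*, float*, int)>();
  fn(x, y, z, d);                          // z = x + y, element-wise
}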

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ class VMulKernel : public Kernel {
 template <typename T>
 class VAddKernel : public Kernel {
  public:
-  virtual void Compute(const T *x, const T *y, T *z) const = 0;
+  void (*Compute)(const T *, const T *, T *, int);
 };
 
 template <typename T>
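
The interface change replaces the pure virtual Compute(x, y, z) with a plain function-pointer member that also takes the length, so the kernel can point it at JIT, MKL, or reference code and call sites avoid virtual dispatch. A minimal sketch of the new calling style (illustrative types only, not the library code):

#include <cstdio>

// Illustrative stand-in for VAddKernel<T>: Compute is now data, not a method.
template <typename T>
struct VAddKernelSketch {
  void (*Compute)(const T*, const T*, T*, int) = nullptr;
};

template <typename T>
void vadd_refer(const T* x, const T* y, T* z, int n) {  // scalar fallback
  for (int i = 0; i < n; ++i) z[i] = x[i] + y[i];
}

int main() {
  VAddKernelSketch<float> k;
  k.Compute = vadd_refer<float>;  // in the real kernel this may instead be a
                                  // JIT-generated or MKL-backed function
  float x[3] = {1, 2, 3}, y[3] = {4, 5, 6}, z[3];
  k.Compute(x, y, z, 3);          // the caller now supplies the length
  std::printf("%g %g %g\n", z[0], z[1], z[2]);  // prints: 5 7 9
}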

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 64 additions & 54 deletions
@@ -39,6 +39,13 @@ void VMulRefer(const T* x, const T* y, T* z, int n) {
   }
 }
 
+template <typename T>
+void VAddRefer(const T* x, const T* y, T* z, int n) {
+  for (int i = 0; i < n; ++i) {
+    z[i] = x[i] + y[i];
+  }
+}
+
 #ifdef PADDLE_WITH_MKLML
 template <typename T>
 void VMulMKL(const T* x, const T* y, T* z, int n);
@@ -47,22 +54,38 @@ template <>
 void VMulMKL<float>(const float* x, const float* y, float* z, int n) {
   platform::dynload::vsMul(n, x, y, z);
 }
+
 template <>
 void VMulMKL<double>(const double* x, const double* y, double* z, int n) {
   platform::dynload::vdMul(n, x, y, z);
 }
+
+template <typename T>
+void VAddMKL(const T* x, const T* y, T* z, int n);
+
+template <>
+void VAddMKL<float>(const float* x, const float* y, float* z, int n) {
+  platform::dynload::vsAdd(n, x, y, z);
+}
+
+template <>
+void VAddMKL<double>(const double* x, const double* y, double* z, int n) {
+  platform::dynload::vdAdd(n, x, y, z);
+}
 #endif
 
+#define DECLARE_STATIC_FUNC \
+  static inline std::string name(int d) { \
+    PADDLE_THROW("DType should be either float or double"); \
+  } \
+  static inline bool useJIT(int d) { return false; } \
+  static inline bool useMKL(int d) { return false; }
+
 /* VMUL JitKernel */
 template <typename T>
 class VMulKernelImpl : public VMulKernel<T> {
  public:
-  static inline std::string name(int d) {
-    PADDLE_THROW("DType should be either float or double");
-  }
-  static inline bool useJIT(int d) { return false; }
-  static inline bool useMKL(int d) { return false; }
-
+  DECLARE_STATIC_FUNC;
   explicit VMulKernelImpl(int d) : VMulKernel<T>() {
     if (useJIT(d)) {
       // roughly estimate the size of code
@@ -100,63 +123,51 @@ bool VMulKernelImpl<double>::useMKL(int d) {
   return true;
 }
 
-REGISTER_JITKERNEL(vmul, VMulKernel);
-
-/* VADD JitKernel */
-template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+/* VAdd JitKernel */
+template <typename T>
 class VAddKernelImpl : public VAddKernel<T> {
  public:
-  explicit VAddKernelImpl(int d) : VAddKernel<T>() { this->num_ = d; }
-  void Compute(const T* x, const T* y, T* z) const override {
-    for (int i = 0; i < this->num_; ++i) {
-      z[i] = x[i] + y[i];
+  DECLARE_STATIC_FUNC;
+  explicit VAddKernelImpl(int d) : VAddKernel<T>() {
+    if (useJIT(d)) {
+      size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
+      jitcode_.reset(new gen::VAddJitCode(d, sz > 4096 ? sz : 4096));
+      this->Compute =
+          jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
+      return;
     }
+#ifdef PADDLE_WITH_MKLML
+    if (useMKL(d)) {
+      this->Compute = VAddMKL<T>;
+      return;
+    }
+#endif
+    this->Compute = VAddRefer<T>;
   }
+
+ private:
+  std::unique_ptr<gen::VAddJitCode> jitcode_{nullptr};
 };
 
-#ifdef PADDLE_WITH_MKLML
-#define MKL_FLOAT(isa, block) \
-  template <> \
-  void VAddKernelImpl<float, isa, block>::Compute( \
-      const float* x, const float* y, float* z) const { \
-    platform::dynload::vsAdd(this->num_, x, y, z); \
-  }
+template <>
+bool VAddKernelImpl<float>::useJIT(int d) {
+  return gen::VAddJitCode::init(d);
+}
 
-#define MKL_DOUBLE(isa, block) \
-  template <> \
-  void VAddKernelImpl<double, isa, block>::Compute( \
-      const double* x, const double* y, double* z) const { \
-    platform::dynload::vdAdd(this->num_, x, y, z); \
-  }
+template <>
+bool VAddKernelImpl<float>::useMKL(int d) {
+  return d > 512;
+}
 
-FOR_EACH_ISA(MKL_FLOAT, kGT16);
-FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
-#endif
+template <>
+bool VAddKernelImpl<double>::useMKL(int d) {
+  return true;
+}
 
-#define INTRI8_FLOAT(isa) \
-  template <> \
-  void VAddKernelImpl<float, isa, kEQ8>::Compute( \
-      const float* x, const float* y, float* z) const { \
-    __m256 tmpx, tmpy; \
-    tmpx = _mm256_loadu_ps(x); \
-    tmpy = _mm256_loadu_ps(y); \
-    tmpx = _mm256_add_ps(tmpx, tmpy); \
-    _mm256_storeu_ps(z, tmpx); \
-  }
-#ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
-#endif
-#ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
-#endif
-#ifdef __AVX512F__
-INTRI8_FLOAT(jit::avx512f);
-#endif
-// TODO(TJ): eq16 test and complete avx512
+#undef DECLARE_STATIC_FUNC
 
-#undef INTRI8_FLOAT
-#undef MKL_FLOAT
-#undef MKL_DOUBLE
+REGISTER_JITKERNEL(vmul, VMulKernel);
+REGISTER_JITKERNEL(vadd, VAddKernel);
 
 /* VSCAL JitKernel */
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
@@ -480,7 +491,6 @@ INTRI_COMMON_FLOAT(jit::avx512f, kGT16);
 #undef INTRI16_FLOAT
 #undef INTRI_COMMON_FLOAT
 
-REGISTER_JITKERNEL_DEPRECATED(vadd, VAddKernel);
 REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel);
 REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel);
 REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
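
VAddKernelImpl now picks its backend in a fixed order: JIT code when AVX is usable (with a code buffer sized roughly 96 + d / AVX_FLOAT_BLOCK * 4 * 8 bytes, floored at 4 KiB), then MKL's vsAdd only for float when d > 512, and the scalar VAddRefer otherwise. A small sketch of that selection order, with the predicates taken from the specializations above (the helper name and enum are hypothetical; the real kernel stores a function pointer rather than an enum):

enum class VAddBackend { kJIT, kMKL, kRefer };

// Hypothetical helper mirroring useJIT()/useMKL() for float.
VAddBackend PickVAddBackendFloat(int d, bool has_avx, bool with_mkl) {
  if (has_avx) return VAddBackend::kJIT;              // gen::VAddJitCode::init(d)
  if (with_mkl && d > 512) return VAddBackend::kMKL;  // platform::dynload::vsAdd
  return VAddBackend::kRefer;                         // VAddRefer scalar loop
}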

paddle/fluid/operators/math/jit_kernel_rnn.cc

Lines changed: 5 additions & 5 deletions
@@ -181,7 +181,7 @@ class LSTMKernelImpl : public LSTMKernel<T> {
     act_cand_d_->Compute(gates, gates);
     vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
     vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
-    vadd_d_->Compute(gates + d_, gates + d2_, ct);
+    vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
 
     /* H_t = act_cell(C_t) * ogated */
     act_cell_d_->Compute(ct, gates + d2_);
@@ -291,16 +291,16 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
     /* get fgated and igated*/
     vmul_d_->Compute(wp_data, ct_1, checked, d_);
     vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
-    vadd_d2_->Compute(checked, gates + d_, gates + d_);
+    vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
     act_gate_d2_->Compute(gates + d_, gates + d_);
     /* C_t = C_t-1 * fgated + cand_gated * igated*/
     act_cand_d_->Compute(gates, gates);
     vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
     vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
-    vadd_d_->Compute(gates + d_, gates + d2_, ct);
+    vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
     /* get ogated*/
     vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
-    vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_);
+    vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
     act_gate_d_->Compute(gates + d3_, gates + d3_);
     /* H_t = act_cell(C_t) * ogated */
     act_cell_d_->Compute(ct, gates + d2_);
@@ -314,7 +314,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
     vmul_d_->Compute(gates, gates + d_, ct, d_);
     /* get outgated, put W_oc * C_t on igated */
     vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
-    vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_);
+    vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
     /* H_t = act_cell(C_t) * ogated */
     act_gate_d_->Compute(gates + d3_, gates + d3_);
     act_cell_d_->Compute(ct, gates + d2_);
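
The only change in this file is the extra trailing argument: every vadd_d_/vadd_d2_ call now passes its element count (d_ or d2_), because Compute is a free function pointer with no stored size. Per the comment in the surrounding code, the update these calls finish is C_t = C_{t-1} * fgated + cand_gated * igated; a scalar view of the final add (an assumed reference, for illustration only):

// gates + d_ holds cand_gated * igated and gates + d2_ holds C_{t-1} * fgated
// after the two vmul_d_ calls; vadd_d_->Compute(gates + d_, gates + d2_, ct, d_)
// is then equivalent to this loop.
void cell_sum_ref(const float* a, const float* b, float* ct, int d) {
  for (int i = 0; i < d; ++i) ct[i] = a[i] + b[i];
}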

paddle/fluid/operators/math/jit_kernel_test.cc

Lines changed: 5 additions & 5 deletions
@@ -371,7 +371,7 @@ void lstm_ctht_better(
   vtanh_d->Compute(gates, gates);
   vmul_d->Compute(gates, gates + d, gates + d, d);
   vmul_d->Compute(ct_1, gates + d2, gates + d2, d);
-  vadd_d->Compute(gates + d, gates + d2, ct);
+  vadd_d->Compute(gates + d, gates + d2, ct, d);
   /* H_t = act_cell(C_t) * ogated */
   vtanh_d->Compute(ct, gates + d2);
   vmul_d->Compute(gates + d2, gates + d * 3, ht, d);
@@ -695,7 +695,7 @@ TEST(JitKernel, vadd) {
 
   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(x_data, y_data, ztgt_data);
+    ker->Compute(x_data, y_data, ztgt_data, d);
   }
   auto ttgte = GetCurrentUS();
 
@@ -723,8 +723,8 @@ void vaddrelu_better(
         const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd,
     const std::shared_ptr<
         const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
-    const float* x, const float* y, float* z) {
-  vadd->Compute(x, y, z);
+    const float* x, const float* y, float* z, int d) {
+  vadd->Compute(x, y, z, d);
   vrelu->Compute(z, z);
 }
 
@@ -752,7 +752,7 @@ TEST(JitKernel, vaddrelu) {
   auto trefe = GetCurrentUS();
   auto tmkls = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data);
+    vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data, d);
   }
   auto tmkle = GetCurrentUS();
   auto ttgts = GetCurrentUS();
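
The test updates simply thread the length d through every Compute call and through vaddrelu_better. A self-contained correctness check in the same spirit (an assumed helper, not part of the test file) would compare any Compute pointer against the scalar reference:

#include <cmath>
#include <vector>

using VAddFn = void (*)(const float*, const float*, float*, int);

// Feed the same data to a candidate Compute pointer and to the scalar
// reference, then compare element-wise, mirroring TEST(JitKernel, vadd).
bool vadd_matches_reference(VAddFn compute, int d, float tol = 1e-3f) {
  std::vector<float> x(d), y(d), ztgt(d), zref(d);
  for (int i = 0; i < d; ++i) {
    x[i] = 0.1f * i;
    y[i] = 1.0f - 0.05f * i;
    zref[i] = x[i] + y[i];  // scalar reference (what VAddRefer computes)
  }
  compute(x.data(), y.data(), ztgt.data(), d);
  for (int i = 0; i < d; ++i) {
    if (std::fabs(ztgt[i] - zref[i]) > tol) return false;
  }
  return true;
}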
