
Commit 03e11f3

add vscal jitcode
1 parent 5b7a9dd commit 03e11f3

File tree

6 files changed: +150 −85 lines changed


paddle/fluid/operators/math/jit_code.cc

Lines changed: 35 additions & 0 deletions

@@ -96,6 +96,41 @@ void VVVJitCode::generate() {
   }
   ret();
 }
+
+bool VScalJitCode::init(int d) { return MayIUse(avx); }
+
+void VScalJitCode::generate() {
+  int offset = 0;
+  vbroadcastss(ymm_src1, ptr[param1]);
+  for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
+    vmovups(ymm_src2, ptr[param2 + offset]);
+    vmulps(ymm_dst, ymm_src1, ymm_src2);
+    vmovups(ptr[param3 + offset], ymm_dst);
+    offset += sizeof(float) * AVX_FLOAT_BLOCK;
+  }
+  int rest = num_ % AVX_FLOAT_BLOCK;
+  if (rest >= 4) {
+    vmovups(xmm_src2, ptr[param2 + offset]);
+    vmulps(xmm_dst, xmm_src1, xmm_src2);
+    vmovups(ptr[param3 + offset], xmm_dst);
+    offset += sizeof(float) * 4;
+    rest -= 4;
+  }
+  if (rest >= 2) {
+    vmovq(xmm_src2, ptr[param2 + offset]);
+    vmulps(xmm_dst, xmm_src1, xmm_src2);
+    vmovq(ptr[param3 + offset], xmm_dst);
+    offset += sizeof(float) * 2;
+    rest -= 2;
+  }
+  if (rest > 0) {
+    vmovss(xmm_src2, ptr[param2 + offset]);
+    vmulss(xmm_dst, xmm_src1, xmm_src2);
+    vmovss(ptr[param3 + offset], xmm_dst);
+  }
+  ret();
+}
+
 } // namespace gen
 } // namespace jitkernel
 } // namespace math
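Note: the emitted kernel computes y = a * x over num_ floats, handling full 8-float YMM blocks first and then a 4-, 2-, and 1-wide tail. A minimal C++ intrinsics sketch of the same computation (assuming AVX_FLOAT_BLOCK == 8; the JIT's stepped 4/2/1 tail is collapsed into a scalar loop here):

#include <immintrin.h>

// What the generated code computes: y[i] = a[0] * x[i] for i in [0, n).
void vscal_equiv(const float* a, const float* x, float* y, int n) {
  const __m256 scalar = _mm256_broadcast_ss(a);  // vbroadcastss ymm_src1
  int i = 0;
  for (; i + 8 <= n; i += 8) {  // one YMM register per 8-float block
    _mm256_storeu_ps(y + i, _mm256_mul_ps(scalar, _mm256_loadu_ps(x + i)));
  }
  for (; i < n; ++i) {  // remainder; the JIT uses 4-, 2-, then 1-wide ops
    y[i] = a[0] * x[i];
  }
}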

paddle/fluid/operators/math/jit_code.h

Lines changed: 28 additions & 2 deletions

@@ -29,9 +29,9 @@ using ymm_t = const Xbyak::Ymm;
 using zmm_t = const Xbyak::Zmm;
 using Label = Xbyak::Label;

-// function: vec = Operand(vec, vec) (maybe with relu)
 typedef enum { mul = 0, add } operand_type;

+// function: vec = Operand(vec, vec) (maybe with relu)
 class VVVJitCode : public JitCode {
  public:
   const char* name() const override {
@@ -41,7 +41,7 @@ class VVVJitCode : public JitCode {
     } else if (type_ == operand_type::add) {
       base += "_Add";
     }
-    base += (with_relu_ ? "_relu" : "");
+    base += (with_relu_ ? "_Relu" : "");
     return base.c_str();
   }
   explicit VVVJitCode(int d, operand_type type, bool with_relu,
@@ -72,6 +72,32 @@ class VVVJitCode : public JitCode {
   ymm_t ymm_zero = ymm_t(2);
 };

+class VScalJitCode : public JitCode {
+ public:
+  DECLARE_JIT_CODE(VScalJitCode);
+  explicit VScalJitCode(int d, size_t code_size = 256 * 1024,
+                        void* code_ptr = nullptr)
+      : JitCode(code_size, code_ptr), num_(d) {}
+  static bool init(int d);
+  void generate() override;
+
+ private:
+  int num_;
+  reg64_t param1{abi_param1};
+  reg64_t param2{abi_param2};
+  reg64_t param3{abi_param3};
+
+  xmm_t xmm_src1 = xmm_t(0);
+  xmm_t xmm_src2 = xmm_t(1);
+  xmm_t xmm_dst = xmm_t(1);
+  xmm_t xmm_zero = xmm_t(2);
+
+  ymm_t ymm_src1 = ymm_t(0);
+  ymm_t ymm_src2 = ymm_t(1);
+  ymm_t ymm_dst = ymm_t(1);
+  ymm_t ymm_zero = ymm_t(2);
+};
+
 } // namespace gen
 } // namespace jitkernel
 } // namespace math
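Note: xmm_src2/xmm_dst and ymm_src2/ymm_dst deliberately alias register 1, so each multiply overwrites the freshly loaded source. A sketch of the intended wiring, mirroring jit_kernel_blas.cc below (the buffer-size arithmetic is that file's heuristic, not a documented API):

// Sketch: build the generator once for length d, then call the emitted code
// through a plain function pointer (AVX_FLOAT_BLOCK is 8 floats).
void run_vscal(int d, const float* x, float* y) {
  size_t sz = 96 + d / 8 * 4 * 8;
  gen::VScalJitCode jitcode(d, sz > 4096 ? sz : 4096);
  auto vscal =
      jitcode.getCode<void (*)(const float*, const float*, float*, int)>();
  const float a = 0.5f;
  vscal(&a, x, y, d);  // the length is baked in at generate() time; the
                       // trailing int only keeps kernel signatures uniform
}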

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 1 addition & 2 deletions

@@ -83,8 +83,7 @@ class VAddReluKernel : public Kernel {
 template <typename T>
 class VScalKernel : public Kernel {
  public:
-  virtual void Compute(const T a, const T *x, T *y) const = 0;
-  virtual void Compute(const T a, T *x) const = 0;
+  void (*Compute)(const T *, const T *, T *, int);
 };

 template <typename T>
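Note: the two virtual overloads collapse into one function-pointer member. The scalar moves behind a pointer, the length becomes an explicit argument, and the old in-place overload is now just aliasing the output to the input. A migration sketch (ker, x, y, n stand for any caller's kernel handle and buffers):

// Old: vscal_->Compute(a, x, y);   vscal_->Compute(a, x);
// New: one signature covers both cases.
const float a = 2.f;
ker->Compute(&a, x, y, n);  // out of place: y[i] = a * x[i]
ker->Compute(&a, x, x, n);  // in place:     x[i] = a * x[i]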

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 72 additions & 71 deletions

@@ -57,6 +57,13 @@ void VAddReluRefer(const T* x, const T* y, T* z, int n) {
   }
 }

+template <typename T>
+void VScalRefer(const T* a, const T* x, T* y, int n) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = a[0] * x[i];
+  }
+}
+
 #ifdef PADDLE_WITH_MKLML
 template <typename T>
 void VMulMKL(const T* x, const T* y, T* z, int n);
@@ -83,6 +90,28 @@ template <>
 void VAddMKL<double>(const double* x, const double* y, double* z, int n) {
   platform::dynload::vdAdd(n, x, y, z);
 }
+
+template <typename T>
+void VScalMKL(const T* a, const T* x, T* y, int n);
+
+template <>
+void VScalMKL<float>(const float* a, const float* x, float* y, int n) {
+  if (x == y) {
+    platform::dynload::cblas_sscal(n, *a, y, 1);
+  } else {
+    VScalRefer<float>(a, x, y, n);
+  }
+}
+
+template <>
+void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
+  if (x == y) {
+    platform::dynload::cblas_dscal(n, *a, y, 1);
+  } else {
+    VScalRefer<double>(a, x, y, n);
+  }
+}
+
 #endif

 #define DECLARE_STATIC_FUNC \
@@ -226,87 +255,60 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
 }
 #endif

-#undef DECLARE_STATIC_FUNC
-
-REGISTER_JITKERNEL(vmul, VMulKernel);
-REGISTER_JITKERNEL(vadd, VAddKernel);
-REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
-
-/* VSCAL JitKernel */
-template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+/* VScal JitKernel */
+template <typename T>
 class VScalKernelImpl : public VScalKernel<T> {
  public:
-  explicit VScalKernelImpl(int d) : VScalKernel<T>() { this->num_ = d; }
-  void Compute(const T a, const T* x, T* y) const override {
-    for (int i = 0; i < this->num_; ++i) {
-      y[i] = a * x[i];
-    }
-  }
-  void Compute(const T a, T* x) const override {
-    for (int i = 0; i < this->num_; ++i) {
-      x[i] = a * x[i];
+  DECLARE_STATIC_FUNC;
+  explicit VScalKernelImpl(int d) : VScalKernel<T>() {
+#ifdef PADDLE_WITH_XBYAK
+    if (useJIT(d)) {
+      size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
+      jitcode_.reset(new gen::VScalJitCode(d, sz > 4096 ? sz : 4096));
+      this->Compute =
+          jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
+      return;
     }
-  }
-};
-
+#endif
 #ifdef PADDLE_WITH_MKLML
-#define MKL_FLOAT(isa, block)                                                \
-  template <>                                                                \
-  void VScalKernelImpl<float, isa, block>::Compute(const float a, float* x) \
-      const {                                                                \
-    platform::dynload::cblas_sscal(this->num_, a, x, 1);                     \
-  }
-
-#define MKL_DOUBLE(isa, block)                                                  \
-  template <>                                                                   \
-  void VScalKernelImpl<double, isa, block>::Compute(const double a, double* x) \
-      const {                                                                   \
-    platform::dynload::cblas_dscal(this->num_, a, x, 1);                        \
-  }
-
-FOR_EACH_ISA(MKL_FLOAT, kGT16);
-FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
+    if (useMKL(d)) {
+      this->Compute = VScalMKL<T>;
+      return;
+    }
 #endif
-
-#define INTRI8_FLOAT(isa)                              \
-  template <>                                          \
-  void VScalKernelImpl<float, isa, kEQ8>::Compute(     \
-      const float a, const float* x, float* y) const { \
-    __m256 tmp;                                        \
-    __m256 scalar = _mm256_set1_ps(a);                 \
-    tmp = _mm256_loadu_ps(x);                          \
-    tmp = _mm256_mul_ps(tmp, scalar);                  \
-    _mm256_storeu_ps(y, tmp);                          \
-  }
-#define INTRI8_INPLACE_FLOAT(isa)                                           \
-  template <>                                                               \
-  void VScalKernelImpl<float, isa, kEQ8>::Compute(const float a, float* x) \
-      const {                                                               \
-    __m256 tmp;                                                             \
-    __m256 scalar = _mm256_set1_ps(a);                                      \
-    tmp = _mm256_loadu_ps(x);                                               \
-    tmp = _mm256_mul_ps(tmp, scalar);                                       \
-    _mm256_storeu_ps(x, tmp);                                               \
+    this->Compute = VScalRefer<T>;
   }
+#ifdef PADDLE_WITH_XBYAK

-#ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
-INTRI8_INPLACE_FLOAT(jit::avx);
+ private:
+  std::unique_ptr<gen::VScalJitCode> jitcode_{nullptr};
 #endif
-#ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
-INTRI8_INPLACE_FLOAT(jit::avx2);
+};
+
+#ifdef PADDLE_WITH_XBYAK
+template <>
+bool VScalKernelImpl<float>::useJIT(int d) {
+  return gen::VScalJitCode::init(d);
+}
 #endif
-#ifdef __AVX512F__
-INTRI8_FLOAT(jit::avx512f);
-INTRI8_INPLACE_FLOAT(jit::avx512f);
+
+#ifdef PADDLE_WITH_MKLML
+template <>
+bool VScalKernelImpl<float>::useMKL(int d) {
+  return d > 512;
+}
+template <>
+bool VScalKernelImpl<double>::useMKL(int d) {
+  return true;
+}
 #endif
-// TODO(TJ): eq16 test and complete avx512

-#undef INTRI8_FLOAT
-#undef INTRI8_INPLACE_FLOAT
-#undef MKL_FLOAT
-#undef MKL_DOUBLE
+#undef DECLARE_STATIC_FUNC
+
+REGISTER_JITKERNEL(vmul, VMulKernel);
+REGISTER_JITKERNEL(vadd, VAddKernel);
+REGISTER_JITKERNEL(vscal, VScalKernel);
+REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);

 /* VAddBias JitKernel */
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
@@ -467,7 +469,6 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
   void Compute(const T* x, T* y) const override {}
 };

-REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel);
 REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel);
 REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
 REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
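Note: the constructor binds Compute exactly once, in priority order JIT > MKL (float only when d > 512) > reference, and VScalMKL falls back to the reference loop when x != y because cblas_?scal only scales in place. One plausible reading of the code-buffer estimate (an assumption, not documented in the commit):

// Hypothetical breakdown of the size heuristic used above:
//   ~96 bytes for prologue, vbroadcastss, the 4/2/1-wide tail, and ret;
//   each full YMM block emits 4 instructions (load, mul, store, advance
//   offset) of at most 8 encoded bytes.
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
size_t code_size = sz > 4096 ? sz : 4096;  // never below 4 KB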

paddle/fluid/operators/math/jit_kernel_exp.cc

Lines changed: 9 additions & 6 deletions

@@ -409,9 +409,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
     vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
   }
   void Compute(const T* x, T* y) const override {
-    vscal_->Compute(static_cast<T>(2), x, y);
+    const T a = static_cast<T>(2);
+    vscal_->Compute(&a, x, y, this->num_);
     vsigmoid_->Compute(y, y);
-    vscal_->Compute(static_cast<T>(2), y);
+    vscal_->Compute(&a, y, y, this->num_);
     vaddbias_->Compute(static_cast<T>(-1), y, y);
   }

@@ -472,9 +473,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
     _mm256_storeu_ps(y, tmp);                  \
     x += AVX_FLOAT_BLOCK;                      \
     y += AVX_FLOAT_BLOCK;                      \
-    vscal_->Compute(2.f, x, y);                \
+    const float a = 2.f;                       \
+    vscal_->Compute(&a, x, y, this->num_);     \
     vsigmoid_->Compute(y, y);                  \
-    vscal_->Compute(2.f, y);                   \
+    vscal_->Compute(&a, y, y, this->num_);     \
     vaddbias_->Compute(-1.f, y, y);            \
   }

@@ -502,9 +504,10 @@ class VTanhKernelImpl : public VTanhKernel<T> {
     }                                          \
     x += this->end_;                           \
     y += this->end_;                           \
-    vscal_->Compute(2.f, x, y);                \
+    const float a = 2.f;                       \
+    vscal_->Compute(&a, x, y, this->num_);     \
     vsigmoid_->Compute(y, y);                  \
-    vscal_->Compute(2.f, y);                   \
+    vscal_->Compute(&a, y, y, this->num_);     \
     vaddbias_->Compute(-1.f, y, y);            \
   }
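Note: the four-call sequence implements tanh through the identity tanh(x) = 2*sigmoid(2x) - 1:

y = 2x              // vscal, a = 2
y = sigmoid(y)      // vsigmoid
y = 2y              // vscal, a = 2
y = y - 1           // vaddbias, bias = -1
  = 2 / (1 + e^(-2x)) - 1
  = (1 - e^(-2x)) / (1 + e^(-2x))
  = tanh(x)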

paddle/fluid/operators/math/jit_kernel_test.cc

Lines changed: 5 additions & 4 deletions

@@ -281,9 +281,10 @@ void vtanh_better(
     const paddle::operators::math::jitkernel::VAddBiasKernel<float>>&
         vaddbias,
     const int n, const float* x, float* y) {
-  vscal->Compute(2.f, x, y);
+  const float tmp1 = 2.f;
+  vscal->Compute(&tmp1, x, y, n);
   vsigmoid->Compute(y, y);
-  vscal->Compute(2.f, y);
+  vscal->Compute(&tmp1, y, y, n);
   vaddbias->Compute(-1.f, y, y);
 }

@@ -531,12 +532,12 @@ TEST(JitKernel, vscal) {

   auto ttgts = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(a, x_data, ztgt_data);
+    ker->Compute(&a, x_data, ztgt_data, d);
   }
   auto ttgte = GetCurrentUS();
   auto ttgts1 = GetCurrentUS();
   for (int i = 0; i < repeat; ++i) {
-    ker->Compute(a, y_data);
+    ker->Compute(&a, y_data, y_data, d);
   }
   auto ttgte1 = GetCurrentUS();
   VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
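Note: with the overloads gone, the second timing loop now exercises the same entry point with aliased input and output. A minimal correctness check of the new interface (a sketch; ker, x_data, ztgt_data, and d reuse the test's existing setup):

const float a = 2.f;
ker->Compute(&a, x_data, ztgt_data, d);  // JIT/MKL/refer chosen at construction
for (int i = 0; i < d; ++i) {
  EXPECT_NEAR(ztgt_data[i], a * x_data[i], 1e-3);
}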
