Skip to content

Commit 5e64244

Browse files
committed
add vaddbias jitcode
test=develop
1 parent 5f7956a commit 5e64244

File tree

5 files changed

+62
-60
lines changed

5 files changed

+62
-60
lines changed

paddle/fluid/operators/math/jit_code.h

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,16 +31,26 @@ using Label = Xbyak::Label;
3131

3232
typedef enum { mul = 0, add } operand_type;
3333

34-
// function: vec = Operand(vec(scalar), vec(scalar)) (maybe with relu)
34+
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
3535
class VXXJitCode : public JitCode {
3636
public:
3737
const char* name() const override {
3838
std::string base = "VXXJitCode";
39+
if (scalar_index_ == 1) {
40+
base += "_Scalar";
41+
} else {
42+
base += "_Vec";
43+
}
3944
if (type_ == operand_type::mul) {
4045
base += "_Mul";
4146
} else if (type_ == operand_type::add) {
4247
base += "_Add";
4348
}
49+
if (scalar_index_ == 2) {
50+
base += "_Scalar";
51+
} else {
52+
base += "_Vec";
53+
}
4454
base += (with_relu_ ? "_Relu" : "");
4555
return base.c_str();
4656
}

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -83,13 +83,15 @@ class VAddReluKernel : public Kernel {
8383
template <typename T>
8484
class VScalKernel : public Kernel {
8585
public:
86+
// y = a.*x
8687
void (*Compute)(const T *, const T *, T *, int);
8788
};
8889

8990
template <typename T>
9091
class VAddBiasKernel : public Kernel {
9192
public:
92-
virtual void Compute(const T a, const T *x, T *y) const = 0;
93+
// y = a.+x
94+
void (*Compute)(const T *, const T *, T *, int);
9395
};
9496

9597
template <typename T>

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 37 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,13 @@ void VScalRefer(const T* a, const T* x, T* y, int n) {
6060
}
6161
}
6262

63+
template <typename T>
64+
void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
65+
for (int i = 0; i < n; ++i) {
66+
y[i] = a[0] + x[i];
67+
}
68+
}
69+
6370
#ifdef PADDLE_WITH_MKLML
6471
template <typename T>
6572
void VMulMKL(const T* x, const T* y, T* z, int n);
@@ -300,62 +307,46 @@ bool VScalKernelImpl<double>::useMKL(int d) {
300307
}
301308
#endif
302309

303-
#undef DECLARE_STATIC_FUNC
304-
305-
REGISTER_JITKERNEL(vmul, VMulKernel);
306-
REGISTER_JITKERNEL(vadd, VAddKernel);
307-
REGISTER_JITKERNEL(vscal, VScalKernel);
308-
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
309-
310310
/* VAddBias JitKernel */
311-
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
311+
template <typename T>
312312
class VAddBiasKernelImpl : public VAddBiasKernel<T> {
313313
public:
314-
explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() { this->num_ = d; }
315-
void Compute(const T a, const T* x, T* y) const override {
316-
for (int i = 0; i < this->num_; ++i) {
317-
y[i] = x[i] + a;
314+
DECLARE_STATIC_FUNC;
315+
explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() {
316+
#ifdef PADDLE_WITH_XBYAK
317+
if (useJIT(d)) {
318+
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
319+
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false,
320+
sz > 4096 ? sz : 4096));
321+
this->Compute =
322+
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
323+
return;
318324
}
319-
}
320-
};
321-
322-
#define INTRI8_FLOAT(isa) \
323-
template <> \
324-
void VAddBiasKernelImpl<float, isa, kEQ8>::Compute( \
325-
const float a, const float* x, float* y) const { \
326-
__m256 tmp = _mm256_loadu_ps(x); \
327-
tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a)); \
328-
_mm256_storeu_ps(y, tmp); \
329-
}
325+
#endif
330326

331-
#define INTRI16_FLOAT(isa) \
332-
template <> \
333-
void VAddBiasKernelImpl<float, isa, kEQ16>::Compute( \
334-
const float a, const float* x, float* y) const { \
335-
__m256 tmp0 = _mm256_loadu_ps(x); \
336-
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
337-
tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a)); \
338-
tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a)); \
339-
_mm256_storeu_ps(y, tmp0); \
340-
_mm256_storeu_ps(y + 8, tmp1); \
327+
this->Compute = VAddBiasRefer<T>;
341328
}
329+
#ifdef PADDLE_WITH_XBYAK
342330

343-
#ifdef __AVX__
344-
INTRI8_FLOAT(jit::avx);
345-
INTRI16_FLOAT(jit::avx);
346-
#endif
347-
#ifdef __AVX2__
348-
INTRI8_FLOAT(jit::avx2);
349-
INTRI16_FLOAT(jit::avx2);
331+
private:
332+
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
350333
#endif
351-
#ifdef __AVX512F__
352-
INTRI8_FLOAT(jit::avx512f);
353-
INTRI16_FLOAT(jit::avx512f);
334+
};
335+
336+
#ifdef PADDLE_WITH_XBYAK
337+
template <>
338+
bool VAddBiasKernelImpl<float>::useJIT(int d) {
339+
return gen::VXXJitCode::init(d, 1);
340+
}
354341
#endif
355-
// TODO(TJ): eq16 test and complete avx512
356342

357-
#undef INTRI8_FLOAT
358-
#undef INTRI16_FLOAT
343+
#undef DECLARE_STATIC_FUNC
344+
345+
REGISTER_JITKERNEL(vmul, VMulKernel);
346+
REGISTER_JITKERNEL(vadd, VAddKernel);
347+
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
348+
REGISTER_JITKERNEL(vscal, VScalKernel);
349+
REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
359350

360351
/* VRelu JitKernel */
361352
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
@@ -466,7 +457,6 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
466457
void Compute(const T* x, T* y) const override {}
467458
};
468459

469-
REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel);
470460
REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
471461
REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
472462

paddle/fluid/operators/math/jit_kernel_exp.cc

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -409,11 +409,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
409409
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
410410
}
411411
void Compute(const T* x, T* y) const override {
412-
const T a = static_cast<T>(2);
412+
const T a = static_cast<T>(2), b = static_cast<T>(-1);
413413
vscal_->Compute(&a, x, y, this->num_);
414414
vsigmoid_->Compute(y, y);
415415
vscal_->Compute(&a, y, y, this->num_);
416-
vaddbias_->Compute(static_cast<T>(-1), y, y);
416+
vaddbias_->Compute(&b, y, y, this->num_);
417417
}
418418

419419
private:
@@ -473,11 +473,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
473473
_mm256_storeu_ps(y, tmp); \
474474
x += AVX_FLOAT_BLOCK; \
475475
y += AVX_FLOAT_BLOCK; \
476-
const float a = 2.f; \
476+
const float a = 2.f, b = -1.f; \
477477
vscal_->Compute(&a, x, y, this->num_); \
478478
vsigmoid_->Compute(y, y); \
479479
vscal_->Compute(&a, y, y, this->num_); \
480-
vaddbias_->Compute(-1.f, y, y); \
480+
vaddbias_->Compute(&b, y, y, this->num_); \
481481
}
482482

483483
#define INTRI_GT16_FLOAT(isa, expisa) \
@@ -504,11 +504,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
504504
} \
505505
x += this->end_; \
506506
y += this->end_; \
507-
const float a = 2.f; \
507+
const float a = 2.f, b = -1.f; \
508508
vscal_->Compute(&a, x, y, this->num_); \
509509
vsigmoid_->Compute(y, y); \
510510
vscal_->Compute(&a, y, y, this->num_); \
511-
vaddbias_->Compute(-1.f, y, y); \
511+
vaddbias_->Compute(&b, y, y, this->num_); \
512512
}
513513

514514
#ifndef __WIN32

paddle/fluid/operators/math/jit_kernel_test.cc

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,7 @@ TEST(JitKernel, vaddbias) {
128128
auto trefe = GetCurrentUS();
129129
auto ttgts = GetCurrentUS();
130130
for (int i = 0; i < repeat; ++i) {
131-
ker->Compute(a, x_data, ztgt_data);
131+
ker->Compute(&a, x_data, ztgt_data, d);
132132
}
133133
auto ttgte = GetCurrentUS();
134134

@@ -281,11 +281,11 @@ void vtanh_better(
281281
const paddle::operators::math::jitkernel::VAddBiasKernel<float>>&
282282
vaddbias,
283283
const int n, const float* x, float* y) {
284-
const float tmp1 = 2.f;
285-
vscal->Compute(&tmp1, x, y, n);
284+
const float a = 2.f, b = -1.f;
285+
vscal->Compute(&a, x, y, n);
286286
vsigmoid->Compute(y, y);
287-
vscal->Compute(&tmp1, y, y, n);
288-
vaddbias->Compute(-1.f, y, y);
287+
vscal->Compute(&a, y, y, n);
288+
vaddbias->Compute(&b, y, y, n);
289289
}
290290

291291
TEST(JitKernel, vtanh) {

0 commit comments

Comments (0)