Skip to content

Commit e8642c3

Browse files
authored
Merge pull request #14265 from tensor-tang/fea/jit/vadd
add vadd, vaddrelu jitcode
2 parents 3d5a990 + 382307b commit e8642c3

File tree

7 files changed

+195
-177
lines changed

7 files changed

+195
-177
lines changed

paddle/fluid/operators/math/fc_compute.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
3636
.template Get<jitkernel::VAddReluKernel<T>>(N);
3737
for (int i = 0; i < M; i++) {
3838
T* dst = Y + i * N;
39-
vaddrelu->Compute(B, dst, dst);
39+
vaddrelu->Compute(B, dst, dst, N);
4040
}
4141
} else {
4242
const auto& vadd = jitkernel::KernelPool::Instance()
@@ -47,7 +47,7 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
4747
#endif
4848
for (int i = 0; i < M; i++) {
4949
T* dst = Y + i * N;
50-
vadd->Compute(B, dst, dst);
50+
vadd->Compute(B, dst, dst, N);
5151
}
5252
}
5353
}

paddle/fluid/operators/math/jit_code.cc

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,48 +24,78 @@ namespace gen {
2424

2525
using namespace platform::jit; // NOLINT
2626

bool VVVJitCode::init(int d) {
  // It's not necessary to use avx512 since it would slow down the frequency
  // and this kernel is not compute bound.
  // NOTE: the element count `d` is not consulted here — any length is
  // accepted as long as plain AVX is available; generate() handles the
  // sub-register tail itself.
  return MayIUse(avx);
}
3232

void VVVJitCode::generate() {
  // Emits machine code computing, element-wise over num_ floats:
  //   param3[i] = op(param1[i], param2[i])      op = mul or add (type_)
  // optionally followed by relu: param3[i] = max(result, 0) when with_relu_.
  // do not need push stack, and do not need save avx512reg if do not use avx512
  int offset = 0;
  if (with_relu_) {
    // Zero the relu constant register once up front.
    // NOTE(review): the xmm tail code below reuses xmm_zero — this assumes
    // xmm_zero aliases the low lanes of ymm_zero (same register index in
    // jit_code.h); confirm if register assignments change.
    vxorps(ymm_zero, ymm_zero, ymm_zero);
  }
  // Main loop: one full AVX register (AVX_FLOAT_BLOCK floats) per iteration.
  for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
    vmovups(ymm_src1, ptr[param1 + offset]);
    vmovups(ymm_src2, ptr[param2 + offset]);
    if (type_ == operand_type::mul) {
      vmulps(ymm_dst, ymm_src1, ymm_src2);
    } else if (type_ == operand_type::add) {
      vaddps(ymm_dst, ymm_src1, ymm_src2);
    }
    if (with_relu_) {
      vmaxps(ymm_dst, ymm_zero, ymm_dst);
    }
    vmovups(ptr[param3 + offset], ymm_dst);
    offset += sizeof(float) * AVX_FLOAT_BLOCK;
  }
  // Tail: the remaining num_ % AVX_FLOAT_BLOCK floats are handled with
  // progressively narrower steps: 4-wide, then 2-wide, then 1 scalar.
  int rest = num_ % AVX_FLOAT_BLOCK;
  if (rest >= 4) {
    vmovups(xmm_src1, ptr[param1 + offset]);
    vmovups(xmm_src2, ptr[param2 + offset]);
    if (type_ == operand_type::mul) {
      vmulps(xmm_dst, xmm_src1, xmm_src2);
    } else if (type_ == operand_type::add) {
      vaddps(xmm_dst, xmm_src1, xmm_src2);
    }
    if (with_relu_) {
      vmaxps(xmm_dst, xmm_zero, xmm_dst);
    }
    vmovups(ptr[param3 + offset], xmm_dst);
    offset += sizeof(float) * 4;
    rest -= 4;
  }
  if (rest >= 2) {
    // Two floats moved as one 64-bit quantity.
    vmovq(xmm_src1, ptr[param1 + offset]);
    vmovq(xmm_src2, ptr[param2 + offset]);
    if (type_ == operand_type::mul) {
      vmulps(xmm_dst, xmm_src1, xmm_src2);
    } else if (type_ == operand_type::add) {
      vaddps(xmm_dst, xmm_src1, xmm_src2);
    }
    if (with_relu_) {
      vmaxps(xmm_dst, xmm_zero, xmm_dst);
    }
    vmovq(ptr[param3 + offset], xmm_dst);
    offset += sizeof(float) * 2;
    rest -= 2;
  }
  if (rest > 0) {
    // Final single float. The packed vmaxps is fine here because only
    // lane 0 is written back via vmovss.
    vmovss(xmm_src1, ptr[param1 + offset]);
    vmovss(xmm_src2, ptr[param2 + offset]);
    if (type_ == operand_type::mul) {
      vmulss(xmm_dst, xmm_src1, xmm_src2);
    } else if (type_ == operand_type::add) {
      vaddss(xmm_dst, xmm_src1, xmm_src2);
    }
    if (with_relu_) {
      vmaxps(xmm_dst, xmm_zero, xmm_dst);
    }
    vmovss(ptr[param3 + offset], xmm_dst);
  }
  ret();
}
68-
6999
} // namespace gen
70100
} // namespace jitkernel
71101
} // namespace math

paddle/fluid/operators/math/jit_code.h

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include <string>
1718
#include "paddle/fluid/operators/math/jit_gen.h"
18-
1919
namespace paddle {
2020
namespace operators {
2121
namespace math {
@@ -29,28 +29,47 @@ using ymm_t = const Xbyak::Ymm;
2929
using zmm_t = const Xbyak::Zmm;
3030
using Label = Xbyak::Label;
3131

32-
class VMulJitCode : public JitCode {
32+
// function: vec = Operand(vec, vec) (maybe with relu)
33+
typedef enum { mul = 0, add } operand_type;
34+
35+
class VVVJitCode : public JitCode {
3336
public:
34-
DECLARE_JIT_CODE(VMulJitCode);
35-
explicit VMulJitCode(int d, size_t code_size = 256 * 1024,
36-
void* code_ptr = nullptr)
37-
: JitCode(code_size, code_ptr), num_(d) {}
37+
const char* name() const override {
38+
std::string base = "VVVJitCode";
39+
if (type_ == operand_type::mul) {
40+
base += "_Mul";
41+
} else if (type_ == operand_type::add) {
42+
base += "_Add";
43+
}
44+
base += (with_relu_ ? "_relu" : "");
45+
return base.c_str();
46+
}
47+
explicit VVVJitCode(int d, operand_type type, bool with_relu,
48+
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
49+
: JitCode(code_size, code_ptr),
50+
num_(d),
51+
type_(type),
52+
with_relu_(with_relu) {}
3853
static bool init(int d);
3954
void generate() override;
4055

4156
private:
4257
int num_;
58+
operand_type type_;
59+
bool with_relu_;
4360
reg64_t param1{abi_param1};
4461
reg64_t param2{abi_param2};
4562
reg64_t param3{abi_param3};
4663

4764
xmm_t xmm_src1 = xmm_t(0);
4865
xmm_t xmm_src2 = xmm_t(1);
49-
xmm_t xmm_dst = xmm_t(2);
66+
xmm_t xmm_dst = xmm_t(1);
67+
xmm_t xmm_zero = xmm_t(2);
5068

5169
ymm_t ymm_src1 = ymm_t(0);
5270
ymm_t ymm_src2 = ymm_t(1);
53-
ymm_t ymm_dst = ymm_t(2);
71+
ymm_t ymm_dst = ymm_t(1);
72+
ymm_t ymm_zero = ymm_t(2);
5473
};
5574

5675
} // namespace gen

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,26 +71,26 @@ class VMulKernel : public Kernel {
7171
template <typename T>
class VAddKernel : public Kernel {
 public:
  // z[i] = x[i] + y[i] for i in [0, n).
  // Raw function pointer rather than a virtual: assigned when the kernel is
  // created (jitted code or reference fallback — see the implementation),
  // avoiding virtual-dispatch overhead on the hot path.
  void (*Compute)(const T *, const T *, T *, int);
};

template <typename T>
class VAddReluKernel : public Kernel {
 public:
  // z[i] = max(x[i] + y[i], 0) for i in [0, n); same function-pointer
  // convention as VAddKernel.
  void (*Compute)(const T *, const T *, T *, int);
};

template <typename T>
class VScalKernel : public Kernel {
 public:
  // y = a * x (element-wise, presumably — verify against implementation).
  virtual void Compute(const T a, const T *x, T *y) const = 0;
  // In-place variant: x = a * x.
  virtual void Compute(const T a, T *x) const = 0;
};

template <typename T>
class VAddBiasKernel : public Kernel {
 public:
  // y = a + x (broadcast scalar bias, presumably — verify against
  // implementation).
  virtual void Compute(const T a, const T *x, T *y) const = 0;
};
9595

9696
template <typename T>

0 commit comments

Comments (0)