Skip to content

Commit 22125eb

Browse files
authored
Merge pull request #14321 from tensor-tang/fea/jit/vscal
Fea jitcode vscal vaddbias
2 parents f1046d7 + 5e64244 commit 22125eb

File tree

6 files changed

+193
-161
lines changed

6 files changed

+193
-161
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,30 @@ namespace gen {
2424

2525
using namespace platform::jit; // NOLINT
2626

27-
bool VVVJitCode::init(int d) {
27+
bool VXXJitCode::init(int d, int scalar_index) {
2828
// It's not necessary to use avx512 since it would slow down the frequency
2929
// and this kernel is not compute bound.
30-
return MayIUse(avx);
30+
return MayIUse(avx) && scalar_index >= 0 && scalar_index <= 2;
3131
}
3232

33-
void VVVJitCode::generate() {
33+
void VXXJitCode::generate() {
3434
// do not need push stack, and do not need save avx512reg if do not use avx512
3535
int offset = 0;
3636
if (with_relu_) {
3737
vxorps(ymm_zero, ymm_zero, ymm_zero);
3838
}
39+
if (scalar_index_ == 1) {
40+
vbroadcastss(ymm_src1, ptr[param1]);
41+
} else if (scalar_index_ == 2) {
42+
vbroadcastss(ymm_src2, ptr[param2]);
43+
}
3944
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
40-
vmovups(ymm_src1, ptr[param1 + offset]);
41-
vmovups(ymm_src2, ptr[param2 + offset]);
45+
if (scalar_index_ != 1) {
46+
vmovups(ymm_src1, ptr[param1 + offset]);
47+
}
48+
if (scalar_index_ != 2) {
49+
vmovups(ymm_src2, ptr[param2 + offset]);
50+
}
4251
if (type_ == operand_type::mul) {
4352
vmulps(ymm_dst, ymm_src1, ymm_src2);
4453
} else if (type_ == operand_type::add) {
@@ -52,8 +61,12 @@ void VVVJitCode::generate() {
5261
}
5362
int rest = num_ % AVX_FLOAT_BLOCK;
5463
if (rest >= 4) {
55-
vmovups(xmm_src1, ptr[param1 + offset]);
56-
vmovups(xmm_src2, ptr[param2 + offset]);
64+
if (scalar_index_ != 1) {
65+
vmovups(xmm_src1, ptr[param1 + offset]);
66+
}
67+
if (scalar_index_ != 2) {
68+
vmovups(xmm_src2, ptr[param2 + offset]);
69+
}
5770
if (type_ == operand_type::mul) {
5871
vmulps(xmm_dst, xmm_src1, xmm_src2);
5972
} else if (type_ == operand_type::add) {
@@ -67,8 +80,12 @@ void VVVJitCode::generate() {
6780
rest -= 4;
6881
}
6982
if (rest >= 2) {
70-
vmovq(xmm_src1, ptr[param1 + offset]);
71-
vmovq(xmm_src2, ptr[param2 + offset]);
83+
if (scalar_index_ != 1) {
84+
vmovups(xmm_src1, ptr[param1 + offset]);
85+
}
86+
if (scalar_index_ != 2) {
87+
vmovups(xmm_src2, ptr[param2 + offset]);
88+
}
7289
if (type_ == operand_type::mul) {
7390
vmulps(xmm_dst, xmm_src1, xmm_src2);
7491
} else if (type_ == operand_type::add) {
@@ -82,8 +99,12 @@ void VVVJitCode::generate() {
8299
rest -= 2;
83100
}
84101
if (rest > 0) {
85-
vmovss(xmm_src1, ptr[param1 + offset]);
86-
vmovss(xmm_src2, ptr[param2 + offset]);
102+
if (scalar_index_ != 1) {
103+
vmovups(xmm_src1, ptr[param1 + offset]);
104+
}
105+
if (scalar_index_ != 2) {
106+
vmovups(xmm_src2, ptr[param2 + offset]);
107+
}
87108
if (type_ == operand_type::mul) {
88109
vmulss(xmm_dst, xmm_src1, xmm_src2);
89110
} else if (type_ == operand_type::add) {
@@ -96,6 +117,7 @@ void VVVJitCode::generate() {
96117
}
97118
ret();
98119
}
120+
99121
} // namespace gen
100122
} // namespace jitkernel
101123
} // namespace math

paddle/fluid/operators/math/jit_code.h

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,47 +29,60 @@ using ymm_t = const Xbyak::Ymm;
2929
using zmm_t = const Xbyak::Zmm;
3030
using Label = Xbyak::Label;
3131

32-
// function: vec = Operand(vec, vec) (maybe with relu)
3332
typedef enum { mul = 0, add } operand_type;
3433

35-
class VVVJitCode : public JitCode {
34+
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
35+
class VXXJitCode : public JitCode {
3636
public:
3737
const char* name() const override {
38-
std::string base = "VVVJitCode";
38+
std::string base = "VXXJitCode";
39+
if (scalar_index_ == 1) {
40+
base += "_Scalar";
41+
} else {
42+
base += "_Vec";
43+
}
3944
if (type_ == operand_type::mul) {
4045
base += "_Mul";
4146
} else if (type_ == operand_type::add) {
4247
base += "_Add";
4348
}
44-
base += (with_relu_ ? "_relu" : "");
49+
if (scalar_index_ == 2) {
50+
base += "_Scalar";
51+
} else {
52+
base += "_Vec";
53+
}
54+
base += (with_relu_ ? "_Relu" : "");
4555
return base.c_str();
4656
}
47-
explicit VVVJitCode(int d, operand_type type, bool with_relu,
48-
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
57+
explicit VXXJitCode(int d, operand_type type, int scalar_index,
58+
bool with_relu, size_t code_size = 256 * 1024,
59+
void* code_ptr = nullptr)
4960
: JitCode(code_size, code_ptr),
5061
num_(d),
5162
type_(type),
63+
scalar_index_(scalar_index),
5264
with_relu_(with_relu) {}
53-
static bool init(int d);
65+
static bool init(int d, int scalar_index = 0);
5466
void generate() override;
5567

5668
private:
5769
int num_;
5870
operand_type type_;
71+
int scalar_index_;
5972
bool with_relu_;
6073
reg64_t param1{abi_param1};
6174
reg64_t param2{abi_param2};
6275
reg64_t param3{abi_param3};
6376

6477
xmm_t xmm_src1 = xmm_t(0);
6578
xmm_t xmm_src2 = xmm_t(1);
66-
xmm_t xmm_dst = xmm_t(1);
67-
xmm_t xmm_zero = xmm_t(2);
79+
xmm_t xmm_dst = xmm_t(2);
80+
xmm_t xmm_zero = xmm_t(3);
6881

6982
ymm_t ymm_src1 = ymm_t(0);
7083
ymm_t ymm_src2 = ymm_t(1);
71-
ymm_t ymm_dst = ymm_t(1);
72-
ymm_t ymm_zero = ymm_t(2);
84+
ymm_t ymm_dst = ymm_t(2);
85+
ymm_t ymm_zero = ymm_t(3);
7386
};
7487

7588
} // namespace gen

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,15 @@ class VAddReluKernel : public Kernel {
8383
template <typename T>
8484
class VScalKernel : public Kernel {
8585
public:
86-
virtual void Compute(const T a, const T *x, T *y) const = 0;
87-
virtual void Compute(const T a, T *x) const = 0;
86+
// y = a.*x
87+
void (*Compute)(const T *, const T *, T *, int);
8888
};
8989

9090
template <typename T>
9191
class VAddBiasKernel : public Kernel {
9292
public:
93-
virtual void Compute(const T a, const T *x, T *y) const = 0;
93+
// y = a.+x
94+
void (*Compute)(const T *, const T *, T *, int);
9495
};
9596

9697
template <typename T>

0 commit comments

Comments
 (0)