Skip to content

Commit 3d950a8

Browse files
committed
combine jitcode of vscal
1 parent 03e11f3 commit 3d950a8

File tree

3 files changed

+58
-93
lines changed

3 files changed

+58
-93
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,30 @@ namespace gen {
2424

2525
using namespace platform::jit; // NOLINT
2626

27-
bool VVVJitCode::init(int d) {
27+
bool VXXJitCode::init(int d, int scalar_index) {
2828
// It's not necessary to use avx512 since it would slow down the frequency
2929
// and this kernel is not compute bound.
30-
return MayIUse(avx);
30+
return MayIUse(avx) && scalar_index >= 0 && scalar_index <= 2;
3131
}
3232

33-
void VVVJitCode::generate() {
33+
void VXXJitCode::generate() {
3434
// do not need push stack, and do not need save avx512reg if do not use avx512
3535
int offset = 0;
3636
if (with_relu_) {
3737
vxorps(ymm_zero, ymm_zero, ymm_zero);
3838
}
39+
if (scalar_index_ == 1) {
40+
vbroadcastss(ymm_src1, ptr[param1]);
41+
} else if (scalar_index_ == 2) {
42+
vbroadcastss(ymm_src2, ptr[param2]);
43+
}
3944
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
40-
vmovups(ymm_src1, ptr[param1 + offset]);
41-
vmovups(ymm_src2, ptr[param2 + offset]);
45+
if (scalar_index_ != 1) {
46+
vmovups(ymm_src1, ptr[param1 + offset]);
47+
}
48+
if (scalar_index_ != 2) {
49+
vmovups(ymm_src2, ptr[param2 + offset]);
50+
}
4251
if (type_ == operand_type::mul) {
4352
vmulps(ymm_dst, ymm_src1, ymm_src2);
4453
} else if (type_ == operand_type::add) {
@@ -52,8 +61,12 @@ void VVVJitCode::generate() {
5261
}
5362
int rest = num_ % AVX_FLOAT_BLOCK;
5463
if (rest >= 4) {
55-
vmovups(xmm_src1, ptr[param1 + offset]);
56-
vmovups(xmm_src2, ptr[param2 + offset]);
64+
if (scalar_index_ != 1) {
65+
vmovups(xmm_src1, ptr[param1 + offset]);
66+
}
67+
if (scalar_index_ != 2) {
68+
vmovups(xmm_src2, ptr[param2 + offset]);
69+
}
5770
if (type_ == operand_type::mul) {
5871
vmulps(xmm_dst, xmm_src1, xmm_src2);
5972
} else if (type_ == operand_type::add) {
@@ -67,8 +80,12 @@ void VVVJitCode::generate() {
6780
rest -= 4;
6881
}
6982
if (rest >= 2) {
70-
vmovq(xmm_src1, ptr[param1 + offset]);
71-
vmovq(xmm_src2, ptr[param2 + offset]);
83+
if (scalar_index_ != 1) {
84+
vmovups(xmm_src1, ptr[param1 + offset]);
85+
}
86+
if (scalar_index_ != 2) {
87+
vmovups(xmm_src2, ptr[param2 + offset]);
88+
}
7289
if (type_ == operand_type::mul) {
7390
vmulps(xmm_dst, xmm_src1, xmm_src2);
7491
} else if (type_ == operand_type::add) {
@@ -82,8 +99,12 @@ void VVVJitCode::generate() {
8299
rest -= 2;
83100
}
84101
if (rest > 0) {
85-
vmovss(xmm_src1, ptr[param1 + offset]);
86-
vmovss(xmm_src2, ptr[param2 + offset]);
102+
if (scalar_index_ != 1) {
103+
vmovups(xmm_src1, ptr[param1 + offset]);
104+
}
105+
if (scalar_index_ != 2) {
106+
vmovups(xmm_src2, ptr[param2 + offset]);
107+
}
87108
if (type_ == operand_type::mul) {
88109
vmulss(xmm_dst, xmm_src1, xmm_src2);
89110
} else if (type_ == operand_type::add) {
@@ -97,40 +118,6 @@ void VVVJitCode::generate() {
97118
ret();
98119
}
99120

100-
bool VScalJitCode::init(int d) { return MayIUse(avx); }
101-
102-
void VScalJitCode::generate() {
103-
int offset = 0;
104-
vbroadcastss(ymm_src1, ptr[param1]);
105-
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
106-
vmovups(ymm_src2, ptr[param2 + offset]);
107-
vmulps(ymm_dst, ymm_src1, ymm_src2);
108-
vmovups(ptr[param3 + offset], ymm_dst);
109-
offset += sizeof(float) * AVX_FLOAT_BLOCK;
110-
}
111-
int rest = num_ % AVX_FLOAT_BLOCK;
112-
if (rest >= 4) {
113-
vmovups(xmm_src2, ptr[param2 + offset]);
114-
vmulps(xmm_dst, xmm_src1, xmm_src2);
115-
vmovups(ptr[param3 + offset], xmm_dst);
116-
offset += sizeof(float) * 4;
117-
rest -= 4;
118-
}
119-
if (rest >= 2) {
120-
vmovq(xmm_src2, ptr[param2 + offset]);
121-
vmulps(xmm_dst, xmm_src1, xmm_src2);
122-
vmovq(ptr[param3 + offset], xmm_dst);
123-
offset += sizeof(float) * 2;
124-
rest -= 2;
125-
}
126-
if (rest > 0) {
127-
vmovss(xmm_src2, ptr[param2 + offset]);
128-
vmulss(xmm_dst, xmm_src1, xmm_src2);
129-
vmovss(ptr[param3 + offset], xmm_dst);
130-
}
131-
ret();
132-
}
133-
134121
} // namespace gen
135122
} // namespace jitkernel
136123
} // namespace math

paddle/fluid/operators/math/jit_code.h

Lines changed: 13 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ using Label = Xbyak::Label;
3131

3232
typedef enum { mul = 0, add } operand_type;
3333

34-
// function: vec = Operand(vec, vec) (maybe with relu)
35-
class VVVJitCode : public JitCode {
34+
// function: vec = Operand(vec(scalar), vec(scalar)) (maybe with relu)
35+
class VXXJitCode : public JitCode {
3636
public:
3737
const char* name() const override {
38-
std::string base = "VVVJitCode";
38+
std::string base = "VXXJitCode";
3939
if (type_ == operand_type::mul) {
4040
base += "_Mul";
4141
} else if (type_ == operand_type::add) {
@@ -44,58 +44,35 @@ class VVVJitCode : public JitCode {
4444
base += (with_relu_ ? "_Relu" : "");
4545
return base.c_str();
4646
}
47-
explicit VVVJitCode(int d, operand_type type, bool with_relu,
48-
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
47+
explicit VXXJitCode(int d, operand_type type, int scalar_index,
48+
bool with_relu, size_t code_size = 256 * 1024,
49+
void* code_ptr = nullptr)
4950
: JitCode(code_size, code_ptr),
5051
num_(d),
5152
type_(type),
53+
scalar_index_(scalar_index),
5254
with_relu_(with_relu) {}
53-
static bool init(int d);
55+
static bool init(int d, int scalar_index = 0);
5456
void generate() override;
5557

5658
private:
5759
int num_;
5860
operand_type type_;
61+
int scalar_index_;
5962
bool with_relu_;
6063
reg64_t param1{abi_param1};
6164
reg64_t param2{abi_param2};
6265
reg64_t param3{abi_param3};
6366

6467
xmm_t xmm_src1 = xmm_t(0);
6568
xmm_t xmm_src2 = xmm_t(1);
66-
xmm_t xmm_dst = xmm_t(1);
67-
xmm_t xmm_zero = xmm_t(2);
69+
xmm_t xmm_dst = xmm_t(2);
70+
xmm_t xmm_zero = xmm_t(3);
6871

6972
ymm_t ymm_src1 = ymm_t(0);
7073
ymm_t ymm_src2 = ymm_t(1);
71-
ymm_t ymm_dst = ymm_t(1);
72-
ymm_t ymm_zero = ymm_t(2);
73-
};
74-
75-
class VScalJitCode : public JitCode {
76-
public:
77-
DECLARE_JIT_CODE(VScalJitCode);
78-
explicit VScalJitCode(int d, size_t code_size = 256 * 1024,
79-
void* code_ptr = nullptr)
80-
: JitCode(code_size, code_ptr), num_(d) {}
81-
static bool init(int d);
82-
void generate() override;
83-
84-
private:
85-
int num_;
86-
reg64_t param1{abi_param1};
87-
reg64_t param2{abi_param2};
88-
reg64_t param3{abi_param3};
89-
90-
xmm_t xmm_src1 = xmm_t(0);
91-
xmm_t xmm_src2 = xmm_t(1);
92-
xmm_t xmm_dst = xmm_t(1);
93-
xmm_t xmm_zero = xmm_t(2);
94-
95-
ymm_t ymm_src1 = ymm_t(0);
96-
ymm_t ymm_src2 = ymm_t(1);
97-
ymm_t ymm_dst = ymm_t(1);
98-
ymm_t ymm_zero = ymm_t(2);
74+
ymm_t ymm_dst = ymm_t(2);
75+
ymm_t ymm_zero = ymm_t(3);
9976
};
10077

10178
} // namespace gen

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class VMulKernelImpl : public VMulKernel<T> {
131131
if (useJIT(d)) {
132132
// roughly estimate the size of code
133133
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
134-
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::mul, false,
134+
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false,
135135
sz > 4096 ? sz : 4096));
136136
this->Compute =
137137
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
@@ -150,14 +150,14 @@ class VMulKernelImpl : public VMulKernel<T> {
150150
#ifdef PADDLE_WITH_XBYAK
151151

152152
private:
153-
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
153+
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
154154
#endif
155155
};
156156

157157
#ifdef PADDLE_WITH_XBYAK
158158
template <>
159159
bool VMulKernelImpl<float>::useJIT(int d) {
160-
return gen::VVVJitCode::init(d);
160+
return gen::VXXJitCode::init(d);
161161
}
162162
#endif
163163

@@ -182,7 +182,7 @@ class VAddKernelImpl : public VAddKernel<T> {
182182
#ifdef PADDLE_WITH_XBYAK
183183
if (useJIT(d)) {
184184
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
185-
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, false,
185+
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false,
186186
sz > 4096 ? sz : 4096));
187187
this->Compute =
188188
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
@@ -200,14 +200,14 @@ class VAddKernelImpl : public VAddKernel<T> {
200200
#ifdef PADDLE_WITH_XBYAK
201201

202202
private:
203-
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
203+
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
204204
#endif
205205
};
206206

207207
#ifdef PADDLE_WITH_XBYAK
208208
template <>
209209
bool VAddKernelImpl<float>::useJIT(int d) {
210-
return gen::VVVJitCode::init(d);
210+
return gen::VXXJitCode::init(d);
211211
}
212212
#endif
213213

@@ -232,7 +232,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
232232
#ifdef PADDLE_WITH_XBYAK
233233
if (useJIT(d)) {
234234
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
235-
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, true,
235+
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true,
236236
sz > 4096 ? sz : 4096));
237237
this->Compute =
238238
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
@@ -244,14 +244,14 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
244244
#ifdef PADDLE_WITH_XBYAK
245245

246246
private:
247-
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
247+
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
248248
#endif
249249
};
250250

251251
#ifdef PADDLE_WITH_XBYAK
252252
template <>
253253
bool VAddReluKernelImpl<float>::useJIT(int d) {
254-
return gen::VVVJitCode::init(d);
254+
return gen::VXXJitCode::init(d);
255255
}
256256
#endif
257257

@@ -264,7 +264,8 @@ class VScalKernelImpl : public VScalKernel<T> {
264264
#ifdef PADDLE_WITH_XBYAK
265265
if (useJIT(d)) {
266266
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
267-
jitcode_.reset(new gen::VScalJitCode(d, sz > 4096 ? sz : 4096));
267+
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false,
268+
sz > 4096 ? sz : 4096));
268269
this->Compute =
269270
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
270271
return;
@@ -281,14 +282,14 @@ class VScalKernelImpl : public VScalKernel<T> {
281282
#ifdef PADDLE_WITH_XBYAK
282283

283284
private:
284-
std::unique_ptr<gen::VScalJitCode> jitcode_{nullptr};
285+
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
285286
#endif
286287
};
287288

288289
#ifdef PADDLE_WITH_XBYAK
289290
template <>
290291
bool VScalKernelImpl<float>::useJIT(int d) {
291-
return gen::VScalJitCode::init(d);
292+
return gen::VXXJitCode::init(d, 1);
292293
}
293294
#endif
294295

0 commit comments

Comments
 (0)