Skip to content

Commit 382307b

Browse files
committed
refine code
test=develop
1 parent 25e070e commit 382307b

File tree

3 files changed

+65
-81
lines changed

3 files changed

+65
-81
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 22 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -24,59 +24,26 @@ namespace gen {
2424

2525
using namespace platform::jit; // NOLINT
2626

27-
bool VMulJitCode::init(int d) {
27+
bool VVVJitCode::init(int d) {
2828
// It's not necessary to use avx512 since it would slow down the frequency
2929
// and this kernel is not compute bound.
3030
return MayIUse(avx);
3131
}
3232

33-
void VMulJitCode::generate() {
33+
void VVVJitCode::generate() {
3434
// do not need push stack, and do not need save avx512reg if do not use avx512
35-
int offset = 0;
36-
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
37-
vmovups(ymm_src1, ptr[param1 + offset]);
38-
vmovups(ymm_src2, ptr[param2 + offset]);
39-
vmulps(ymm_dst, ymm_src1, ymm_src2);
40-
vmovups(ptr[param3 + offset], ymm_dst);
41-
offset += sizeof(float) * AVX_FLOAT_BLOCK;
42-
}
43-
int rest = num_ % AVX_FLOAT_BLOCK;
44-
if (rest >= 4) {
45-
vmovups(xmm_src1, ptr[param1 + offset]);
46-
vmovups(xmm_src2, ptr[param2 + offset]);
47-
vmulps(xmm_dst, xmm_src1, xmm_src2);
48-
vmovups(ptr[param3 + offset], xmm_dst);
49-
offset += sizeof(float) * 4;
50-
rest -= 4;
51-
}
52-
if (rest >= 2) {
53-
vmovq(xmm_src1, ptr[param1 + offset]);
54-
vmovq(xmm_src2, ptr[param2 + offset]);
55-
vmulps(xmm_dst, xmm_src1, xmm_src2);
56-
vmovq(ptr[param3 + offset], xmm_dst);
57-
offset += sizeof(float) * 2;
58-
rest -= 2;
59-
}
60-
if (rest > 0) {
61-
vmovss(xmm_src1, ptr[param1 + offset]);
62-
vmovss(xmm_src2, ptr[param2 + offset]);
63-
vmulss(xmm_dst, xmm_src1, xmm_src2);
64-
vmovss(ptr[param3 + offset], xmm_dst);
65-
}
66-
ret();
67-
}
68-
69-
bool VAddJitCode::init(int d) { return MayIUse(avx); }
70-
71-
void VAddJitCode::generate() {
7235
int offset = 0;
7336
if (with_relu_) {
7437
vxorps(ymm_zero, ymm_zero, ymm_zero);
7538
}
7639
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
7740
vmovups(ymm_src1, ptr[param1 + offset]);
7841
vmovups(ymm_src2, ptr[param2 + offset]);
79-
vaddps(ymm_dst, ymm_src1, ymm_src2);
42+
if (type_ == operand_type::mul) {
43+
vmulps(ymm_dst, ymm_src1, ymm_src2);
44+
} else if (type_ == operand_type::add) {
45+
vaddps(ymm_dst, ymm_src1, ymm_src2);
46+
}
8047
if (with_relu_) {
8148
vmaxps(ymm_dst, ymm_zero, ymm_dst);
8249
}
@@ -87,7 +54,11 @@ void VAddJitCode::generate() {
8754
if (rest >= 4) {
8855
vmovups(xmm_src1, ptr[param1 + offset]);
8956
vmovups(xmm_src2, ptr[param2 + offset]);
90-
vaddps(xmm_dst, xmm_src1, xmm_src2);
57+
if (type_ == operand_type::mul) {
58+
vmulps(xmm_dst, xmm_src1, xmm_src2);
59+
} else if (type_ == operand_type::add) {
60+
vaddps(xmm_dst, xmm_src1, xmm_src2);
61+
}
9162
if (with_relu_) {
9263
vmaxps(xmm_dst, xmm_zero, xmm_dst);
9364
}
@@ -98,7 +69,11 @@ void VAddJitCode::generate() {
9869
if (rest >= 2) {
9970
vmovq(xmm_src1, ptr[param1 + offset]);
10071
vmovq(xmm_src2, ptr[param2 + offset]);
101-
vaddps(xmm_dst, xmm_src1, xmm_src2);
72+
if (type_ == operand_type::mul) {
73+
vmulps(xmm_dst, xmm_src1, xmm_src2);
74+
} else if (type_ == operand_type::add) {
75+
vaddps(xmm_dst, xmm_src1, xmm_src2);
76+
}
10277
if (with_relu_) {
10378
vmaxps(xmm_dst, xmm_zero, xmm_dst);
10479
}
@@ -109,7 +84,11 @@ void VAddJitCode::generate() {
10984
if (rest > 0) {
11085
vmovss(xmm_src1, ptr[param1 + offset]);
11186
vmovss(xmm_src2, ptr[param2 + offset]);
112-
vaddss(xmm_dst, xmm_src1, xmm_src2);
87+
if (type_ == operand_type::mul) {
88+
vmulss(xmm_dst, xmm_src1, xmm_src2);
89+
} else if (type_ == operand_type::add) {
90+
vaddss(xmm_dst, xmm_src1, xmm_src2);
91+
}
11392
if (with_relu_) {
11493
vmaxps(xmm_dst, xmm_zero, xmm_dst);
11594
}

paddle/fluid/operators/math/jit_code.h

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include <string>
1718
#include "paddle/fluid/operators/math/jit_gen.h"
18-
1919
namespace paddle {
2020
namespace operators {
2121
namespace math {
@@ -29,41 +29,33 @@ using ymm_t = const Xbyak::Ymm;
2929
using zmm_t = const Xbyak::Zmm;
3030
using Label = Xbyak::Label;
3131

32-
class VMulJitCode : public JitCode {
33-
public:
34-
DECLARE_JIT_CODE(VMulJitCode);
35-
explicit VMulJitCode(int d, size_t code_size = 256 * 1024,
36-
void* code_ptr = nullptr)
37-
: JitCode(code_size, code_ptr), num_(d) {}
38-
static bool init(int d);
39-
void generate() override;
40-
41-
private:
42-
int num_;
43-
reg64_t param1{abi_param1};
44-
reg64_t param2{abi_param2};
45-
reg64_t param3{abi_param3};
46-
47-
xmm_t xmm_src1 = xmm_t(0);
48-
xmm_t xmm_src2 = xmm_t(1);
49-
xmm_t xmm_dst = xmm_t(1);
50-
51-
ymm_t ymm_src1 = ymm_t(0);
52-
ymm_t ymm_src2 = ymm_t(1);
53-
ymm_t ymm_dst = ymm_t(1);
54-
};
32+
// function: vec = Operand(vec, vec) (maybe with relu)
33+
typedef enum { mul = 0, add } operand_type;
5534

56-
class VAddJitCode : public JitCode {
35+
class VVVJitCode : public JitCode {
5736
public:
58-
DECLARE_JIT_CODE(VAddJitCode);
59-
explicit VAddJitCode(int d, bool with_relu, size_t code_size = 256 * 1024,
60-
void* code_ptr = nullptr)
61-
: JitCode(code_size, code_ptr), num_(d), with_relu_(with_relu) {}
37+
const char* name() const override {
38+
std::string base = "VVVJitCode";
39+
if (type_ == operand_type::mul) {
40+
base += "_Mul";
41+
} else if (type_ == operand_type::add) {
42+
base += "_Add";
43+
}
44+
base += (with_relu_ ? "_relu" : "");
45+
return base.c_str();
46+
}
47+
explicit VVVJitCode(int d, operand_type type, bool with_relu,
48+
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
49+
: JitCode(code_size, code_ptr),
50+
num_(d),
51+
type_(type),
52+
with_relu_(with_relu) {}
6253
static bool init(int d);
6354
void generate() override;
6455

6556
private:
6657
int num_;
58+
operand_type type_;
6759
bool with_relu_;
6860
reg64_t param1{abi_param1};
6961
reg64_t param2{abi_param2};

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ class VMulKernelImpl : public VMulKernel<T> {
102102
if (useJIT(d)) {
103103
// roughly estimate the size of code
104104
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
105-
jitcode_.reset(new gen::VMulJitCode(d, sz > 4096 ? sz : 4096));
105+
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::mul, false,
106+
sz > 4096 ? sz : 4096));
106107
this->Compute =
107108
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
108109
return;
@@ -120,14 +121,14 @@ class VMulKernelImpl : public VMulKernel<T> {
120121
#ifdef PADDLE_WITH_XBYAK
121122

122123
private:
123-
std::unique_ptr<gen::VMulJitCode> jitcode_{nullptr};
124+
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
124125
#endif
125126
};
126127

127128
#ifdef PADDLE_WITH_XBYAK
128129
template <>
129130
bool VMulKernelImpl<float>::useJIT(int d) {
130-
return gen::VMulJitCode::init(d);
131+
return gen::VVVJitCode::init(d);
131132
}
132133
#endif
133134

@@ -149,13 +150,16 @@ class VAddKernelImpl : public VAddKernel<T> {
149150
public:
150151
DECLARE_STATIC_FUNC;
151152
explicit VAddKernelImpl(int d) : VAddKernel<T>() {
153+
#ifdef PADDLE_WITH_XBYAK
152154
if (useJIT(d)) {
153155
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
154-
jitcode_.reset(new gen::VAddJitCode(d, false, sz > 4096 ? sz : 4096));
156+
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, false,
157+
sz > 4096 ? sz : 4096));
155158
this->Compute =
156159
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
157160
return;
158161
}
162+
#endif
159163
#ifdef PADDLE_WITH_MKLML
160164
if (useMKL(d)) {
161165
this->Compute = VAddMKL<T>;
@@ -166,14 +170,17 @@ class VAddKernelImpl : public VAddKernel<T> {
166170
}
167171

168172
private:
169-
std::unique_ptr<gen::VAddJitCode> jitcode_{nullptr};
173+
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
170174
};
171175

176+
#ifdef PADDLE_WITH_XBYAK
172177
template <>
173178
bool VAddKernelImpl<float>::useJIT(int d) {
174-
return gen::VAddJitCode::init(d);
179+
return gen::VVVJitCode::init(d);
175180
}
181+
#endif
176182

183+
#ifdef PADDLE_WITH_MKLML
177184
template <>
178185
bool VAddKernelImpl<float>::useMKL(int d) {
179186
return d > 512;
@@ -183,31 +190,37 @@ template <>
183190
bool VAddKernelImpl<double>::useMKL(int d) {
184191
return true;
185192
}
193+
#endif
186194

187195
/* VAddRelu JitKernel */
188196
template <typename T>
189197
class VAddReluKernelImpl : public VAddReluKernel<T> {
190198
public:
191199
DECLARE_STATIC_FUNC;
192200
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() {
201+
#ifdef PADDLE_WITH_XBYAK
193202
if (useJIT(d)) {
194203
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
195-
jitcode_.reset(new gen::VAddJitCode(d, true, sz > 4096 ? sz : 4096));
204+
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, true,
205+
sz > 4096 ? sz : 4096));
196206
this->Compute =
197207
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
198208
return;
199209
}
210+
#endif
200211
this->Compute = VAddReluRefer<T>;
201212
}
202213

203214
private:
204-
std::unique_ptr<gen::VAddJitCode> jitcode_{nullptr};
215+
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
205216
};
206217

218+
#ifdef PADDLE_WITH_XBYAK
207219
template <>
208220
bool VAddReluKernelImpl<float>::useJIT(int d) {
209-
return gen::VAddJitCode::init(d);
221+
return gen::VVVJitCode::init(d);
210222
}
223+
#endif
211224

212225
#undef DECLARE_STATIC_FUNC
213226

0 commit comments

Comments
 (0)