Skip to content

Commit f65ddff

Browse files
committed
unify act jitcode of relu, exp, sigmoid and tanh
1 parent 6a15907 commit f65ddff

File tree

4 files changed

+153
-159
lines changed

4 files changed

+153
-159
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 83 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -118,40 +118,6 @@ void VXXJitCode::generate() {
118118
ret();
119119
}
120120

121-
bool ReluJitCode::init(int d) { return MayIUse(avx); }
122-
123-
void ReluJitCode::generate() {
124-
int offset = 0;
125-
vxorps(ymm_zero, ymm_zero, ymm_zero);
126-
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
127-
vmovups(ymm_src, ptr[param1 + offset]);
128-
vmaxps(ymm_dst, ymm_zero, ymm_src);
129-
vmovups(ptr[param2 + offset], ymm_dst);
130-
offset += sizeof(float) * AVX_FLOAT_BLOCK;
131-
}
132-
int rest = num_ % AVX_FLOAT_BLOCK;
133-
if (rest >= 4) {
134-
vmovups(xmm_src, ptr[param1 + offset]);
135-
vmaxps(xmm_dst, xmm_zero, xmm_src);
136-
vmovups(ptr[param2 + offset], xmm_dst);
137-
offset += sizeof(float) * 4;
138-
rest -= 4;
139-
}
140-
if (rest >= 2) {
141-
vmovups(xmm_src, ptr[param1 + offset]);
142-
vmaxps(xmm_dst, xmm_zero, xmm_src);
143-
vmovq(ptr[param2 + offset], xmm_dst);
144-
offset += sizeof(float) * 2;
145-
rest -= 2;
146-
}
147-
if (rest > 0) {
148-
vmovups(xmm_src, ptr[param1 + offset]);
149-
vmaxps(xmm_dst, xmm_zero, xmm_src);
150-
vmovss(ptr[param2 + offset], xmm_dst);
151-
}
152-
ret();
153-
}
154-
155121
#define ALIGN32 __attribute__((aligned(32)))
156122
#define EXP_HIG 88.3762626647949f
157123
#define EXP_LOW -88.3762626647949f
@@ -207,18 +173,28 @@ static const float exp_float_consts[] ALIGN32 = {
207173
static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
208174
static int g_tmp_mem[16] ALIGN32 = {0};
209175

210-
bool VExpJitCode::init(int d) {
211-
return MayIUse(avx) && d == 8; // only 8 yet
176+
bool VActJitCode::init(int d, operand_type type) {
177+
bool ok = MayIUse(avx);
178+
if (type == operand_type::relu) {
179+
return ok;
180+
} else {
181+
return ok && d == 8; // only 8 yet
182+
}
212183
}
213184

214-
void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
215-
// use reg rax and ymm 2~5
216-
reg64_t reg_ptr_global = rax;
217-
ymm_t ymm_fx = ymm_t(2);
218-
ymm_t ymm_fy = ymm_t(3);
219-
ymm_t ymm_mask = ymm_t(4);
220-
ymm_t ymm_tmp = ymm_t(5);
185+
// Emits a single packed-float max of ymm_src against ymm_zero and places the
// result in ymm_dst. The caller is expected to have cleared ymm_zero
// beforehand (generate() does so with vxorps when the type is relu).
void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
  vmaxps(ymm_dst, ymm_zero, ymm_src);
}
188+
189+
void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
190+
int fy_idx, int mask_idx, int tmp_idx) {
221191
assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore
192+
// check that no two register indices are equal
193+
ymm_t ymm_fx = ymm_t(fx_idx);
194+
ymm_t ymm_fy = ymm_t(fy_idx);
195+
ymm_t ymm_mask = ymm_t(mask_idx);
196+
ymm_t ymm_tmp = ymm_t(tmp_idx);
197+
reg64_t reg_ptr_global = rax;
222198
push(reg_ptr_global);
223199
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
224200
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
@@ -291,22 +267,11 @@ void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
291267
pop(reg_ptr_global);
292268
}
293269

294-
void VExpJitCode::generate() {
295-
int offset = 0;
296-
vmovups(ymm_src, ptr[param1 + offset]);
297-
exp_ymm(ymm_src, ymm_dst);
298-
vmovups(ptr[param2 + offset], ymm_dst);
299-
ret();
300-
}
301-
302-
bool VSigmoidJitCode::init(int d) {
303-
return MayIUse(avx) && d == 8; // only 8 yet
304-
}
305-
306-
void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
307-
// use ymm2
270+
void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
271+
int fy_idx, int mask_idx, int tmp_idx) {
272+
// y = 1 / (1 + e^-x)
273+
ymm_t ymm_tmp = ymm_t(tmp_idx);
308274
reg64_t reg_ptr_global = rax;
309-
ymm_t ymm_tmp = ymm_t(2);
310275
push(reg_ptr_global);
311276
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
312277
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]);
@@ -315,38 +280,26 @@ void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
315280
vmaxps(ymm_src, ymm_src, ymm_tmp);
316281
vxorps(ymm_tmp, ymm_tmp, ymm_tmp);
317282
vsubps(ymm_src, ymm_tmp, ymm_src);
318-
exp_ymm(ymm_src, ymm_dst);
283+
exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
319284
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
320285
vaddps(ymm_dst, ymm_dst, ymm_tmp);
321286
vdivps(ymm_dst, ymm_tmp, ymm_dst);
322287
pop(reg_ptr_global);
323288
}
324289

325-
void VSigmoidJitCode::generate() {
326-
int offset = 0;
327-
vmovups(ymm_src, ptr[param1 + offset]);
328-
sigmoid_ymm(ymm_src, ymm_dst);
329-
vmovups(ptr[param2 + offset], ymm_dst);
330-
ret();
331-
}
332-
333-
bool VTanhJitCode::init(int d) {
334-
return MayIUse(avx) && d == 8; // only 8 yet
335-
}
336-
337-
void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
290+
void VActJitCode::tanh_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
291+
int fy_idx, int mask_idx, int tmp_idx) {
338292
// y = 2 / (1 + e^(-2x)) - 1
339-
// use ymm2, ymm3
293+
ymm_t ymm_tmp = ymm_t(tmp_idx);
294+
ymm_t ymm_zero = ymm_t(mask_idx);
340295
reg64_t reg_ptr_global = rax;
341-
ymm_t ymm_tmp = ymm_t(2);
342-
ymm_t ymm_zero = ymm_t(3);
343296
push(reg_ptr_global);
344297
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
345298
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
346299
vxorps(ymm_zero, ymm_zero, ymm_zero);
347300
vsubps(ymm_tmp, ymm_zero, ymm_tmp);
348301
vmulps(ymm_src, ymm_src, ymm_tmp);
349-
exp_ymm(ymm_src, ymm_dst);
302+
exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx);
350303
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]);
351304
vaddps(ymm_dst, ymm_dst, ymm_tmp);
352305
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]);
@@ -356,11 +309,61 @@ void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
356309
pop(reg_ptr_global);
357310
}
358311

359-
void VTanhJitCode::generate() {
312+
void VActJitCode::generate() {
313+
xmm_t xmm_zero = xmm_t(2);
314+
ymm_t ymm_zero = ymm_t(2);
315+
if (type_ == operand_type::relu) {
316+
vxorps(ymm_zero, ymm_zero, ymm_zero);
317+
}
360318
int offset = 0;
361-
vmovups(ymm_src, ptr[param1 + offset]);
362-
vtanh_ymm(ymm_src, ymm_dst);
363-
vmovups(ptr[param2 + offset], ymm_dst);
319+
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
320+
vmovups(ymm_src, ptr[param1 + offset]);
321+
switch (type_) {
322+
case operand_type::relu:
323+
relu_ymm(ymm_dst, ymm_src, ymm_zero);
324+
break;
325+
case operand_type::exp:
326+
exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
327+
break;
328+
case operand_type::sigmoid:
329+
sigmoid_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
330+
break;
331+
case operand_type::tanh:
332+
tanh_ymm(ymm_dst, ymm_src, 2, 3, 4, 5);
333+
break;
334+
case operand_type::identity:
335+
break;
336+
default:
337+
break;
338+
}
339+
vmovups(ptr[param2 + offset], ymm_dst);
340+
offset += sizeof(float) * AVX_FLOAT_BLOCK;
341+
}
342+
if (type_ != operand_type::relu) {
343+
// TODO(TJ): remove me
344+
ret();
345+
return;
346+
}
347+
int rest = num_ % AVX_FLOAT_BLOCK;
348+
if (rest >= 4) {
349+
vmovups(xmm_src, ptr[param1 + offset]);
350+
vmaxps(xmm_dst, xmm_zero, xmm_src);
351+
vmovups(ptr[param2 + offset], xmm_dst);
352+
offset += sizeof(float) * 4;
353+
rest -= 4;
354+
}
355+
if (rest >= 2) {
356+
vmovups(xmm_src, ptr[param1 + offset]);
357+
vmaxps(xmm_dst, xmm_zero, xmm_src);
358+
vmovq(ptr[param2 + offset], xmm_dst);
359+
offset += sizeof(float) * 2;
360+
rest -= 2;
361+
}
362+
if (rest > 0) {
363+
vmovups(xmm_src, ptr[param1 + offset]);
364+
vmaxps(xmm_dst, xmm_zero, xmm_src);
365+
vmovss(ptr[param2 + offset], xmm_dst);
366+
}
364367
ret();
365368
}
366369

paddle/fluid/operators/math/jit_code.h

Lines changed: 54 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm;
2929
using zmm_t = const Xbyak::Zmm;
3030
using Label = Xbyak::Label;
3131

32-
typedef enum { mul = 0, add } operand_type;
32+
// Selects which operation a JIT code generator emits.
//
// Kept as an unscoped enum so that both `operand_type::relu` and the bare
// enumerator names continue to work for existing users.
enum operand_type {
  // Binary operations. Presumably consumed by VXXJitCode -- confirm against
  // its users; `sub` is not referenced in the code visible here.
  mul = 0,
  add,
  sub,
  // Activations handled by VActJitCode.
  relu,
  exp,
  sigmoid,
  tanh,
  identity
};
3342

3443
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
3544
class VXXJitCode : public JitCode {
@@ -85,87 +94,65 @@ class VXXJitCode : public JitCode {
8594
ymm_t ymm_zero = ymm_t(3);
8695
};
8796

88-
class ReluJitCode : public JitCode {
97+
class VActJitCode : public JitCode {
8998
public:
90-
DECLARE_JIT_CODE(ReluJitCode);
91-
explicit ReluJitCode(int d, size_t code_size = 256 * 1024,
92-
void* code_ptr = nullptr)
93-
: JitCode(code_size, code_ptr), num_(d) {}
94-
static bool init(int d);
95-
void generate() override;
96-
97-
private:
98-
int num_;
99-
reg64_t param1{abi_param1};
100-
reg64_t param2{abi_param2};
101-
102-
xmm_t xmm_zero = xmm_t(0);
103-
xmm_t xmm_src = xmm_t(1);
104-
xmm_t xmm_dst = xmm_t(1);
105-
106-
ymm_t ymm_zero = ymm_t(0);
107-
ymm_t ymm_src = ymm_t(1);
108-
ymm_t ymm_dst = ymm_t(1);
109-
};
99+
const char* name() const override {
100+
std::string base = "VActJitCode";
101+
switch (type_) {
102+
case operand_type::relu:
103+
base += "_Relu";
104+
break;
105+
case operand_type::exp:
106+
base += "_Exp";
107+
break;
108+
case operand_type::sigmoid:
109+
base += "_Sigmoid";
110+
break;
111+
case operand_type::tanh:
112+
base += "_Tanh";
113+
break;
114+
case operand_type::identity:
115+
base += "_Identity";
116+
break;
117+
default:
118+
break;
119+
}
120+
return base.c_str();
121+
}
110122

111-
class VExpJitCode : public JitCode {
112-
public:
113-
DECLARE_JIT_CODE(VExpJitCode);
114-
explicit VExpJitCode(int d, size_t code_size = 256 * 1024,
123+
explicit VActJitCode(int d, operand_type type, size_t code_size = 256 * 1024,
115124
void* code_ptr = nullptr)
116-
: JitCode(code_size, code_ptr), num_(d) {}
117-
static bool init(int d);
125+
: JitCode(code_size, code_ptr), num_(d), type_(type) {}
126+
static bool init(int d, operand_type type);
118127
void generate() override;
119128

120129
protected:
121-
// compute exp with ymm
122-
void exp_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst);
130+
// compute relu with ymm
131+
void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src,
132+
const Xbyak::Ymm& zero);
123133

124-
private:
125-
int num_;
126-
reg64_t param1{abi_param1};
127-
reg64_t param2{abi_param2};
128-
ymm_t ymm_src = ymm_t(0);
129-
ymm_t ymm_dst = ymm_t(1);
130-
};
131-
132-
class VSigmoidJitCode : public VExpJitCode {
133-
public:
134-
DECLARE_JIT_CODE(VSigmoidJitCode);
135-
explicit VSigmoidJitCode(int d, size_t code_size = 256 * 1024,
136-
void* code_ptr = nullptr)
137-
: VExpJitCode(d, code_size, code_ptr), num_(d) {}
138-
static bool init(int d);
139-
void generate() override;
134+
// compute exp with ymm
135+
void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
136+
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
140137

141138
// compute sigmoid with ymm
142-
void sigmoid_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst);
139+
void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
140+
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
143141

144-
private:
145-
int num_;
146-
reg64_t param1{abi_param1};
147-
reg64_t param2{abi_param2};
148-
ymm_t ymm_src = ymm_t(0);
149-
ymm_t ymm_dst = ymm_t(1);
150-
};
142+
// compute tanh with ymm
143+
void tanh_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
144+
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
151145

152-
class VTanhJitCode : public VExpJitCode {
153-
public:
154-
DECLARE_JIT_CODE(VTanhJitCode);
155-
explicit VTanhJitCode(int d, size_t code_size = 256 * 1024,
156-
void* code_ptr = nullptr)
157-
: VExpJitCode(d, code_size, code_ptr), num_(d) {}
158-
static bool init(int d);
159-
void generate() override;
160-
161-
// compute sigmoid with ymm
162-
void vtanh_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst);
163-
164-
private:
146+
protected:
165147
int num_;
148+
operand_type type_;
166149
reg64_t param1{abi_param1};
167150
reg64_t param2{abi_param2};
151+
152+
xmm_t xmm_src = xmm_t(0);
168153
ymm_t ymm_src = ymm_t(0);
154+
155+
xmm_t xmm_dst = xmm_t(1);
169156
ymm_t ymm_dst = ymm_t(1);
170157
};
171158

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,8 @@ class VReluKernelImpl : public VReluKernel<T> {
352352
size_t sz = 96 /* init size */ +
353353
d / AVX_FLOAT_BLOCK * 4 /* instructions */ *
354354
8 /* average bytes for each instruction */;
355-
jitcode_.reset(new gen::ReluJitCode(d, sz > 4096 ? sz : 4096));
355+
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::relu,
356+
sz > 4096 ? sz : 4096));
356357
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
357358
return;
358359
}
@@ -366,14 +367,14 @@ class VReluKernelImpl : public VReluKernel<T> {
366367
#ifdef PADDLE_WITH_XBYAK
367368

368369
private:
369-
std::unique_ptr<gen::ReluJitCode> jitcode_{nullptr};
370+
std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
370371
#endif
371372
};
372373

373374
#ifdef PADDLE_WITH_XBYAK
374375
template <>
375376
bool VReluKernelImpl<float>::useJIT(int d) {
376-
return gen::ReluJitCode::init(d);
377+
return gen::VActJitCode::init(d, gen::operand_type::relu);
377378
}
378379
#endif
379380

0 commit comments

Comments
 (0)