Skip to content

Commit e3b61cf

Browse files
committed
init gru jitcode and fix lstm jitcode
test=develop
1 parent 0f25446 commit e3b61cf

File tree

3 files changed

+170
-42
lines changed

3 files changed

+170
-42
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,9 @@ void VActJitCode::generate() {
214214
// Reports whether the LSTM JIT path may be used for width `d`: AVX must be
// usable and `d` must be a multiple of 8.
bool LSTMJitCode::init(int d) {
  if (!MayIUse(avx)) {
    return false;
  }
  return d % 8 == 0;
}
215215

216216
void LSTMJitCode::generate() {
217+
if (use_peephole_) {
218+
preCode();
219+
}
217220
reg64_t reg_ptr_gates = rax;
218221
reg64_t reg_ptr_ct_1 = r9;
219222
reg64_t reg_ptr_ct = r10;
@@ -224,18 +227,19 @@ void LSTMJitCode::generate() {
224227
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
225228

226229
int offset = 0;
230+
int d = num_ * sizeof(float);
227231
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
228232
/* C_t = C_t-1 * fgated + cand_gated * igated*/
229233
// c
230234
vmovups(ymm_src, ptr[reg_ptr_gates + offset]);
231235
act<ymm_t>(ymm_c, ymm_src, act_cand_);
232236
// i
233-
vmovups(ymm_src, ptr[reg_ptr_gates + offset + num_]);
237+
vmovups(ymm_src, ptr[reg_ptr_gates + offset + d]);
234238
act<ymm_t>(ymm_i, ymm_src, act_gate_);
235239
vmulps(ymm_c, ymm_c, ymm_i);
236240
if (!compute_c1h1_) {
237241
// f
238-
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * num_]);
242+
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * d]);
239243
act<ymm_t>(ymm_f, ymm_src, act_gate_);
240244
vmovups(ymm_i, ptr[reg_ptr_ct_1 + offset]);
241245
vmulps(ymm_f, ymm_f, ymm_i);
@@ -245,20 +249,36 @@ void LSTMJitCode::generate() {
245249
ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
246250
ymm_t ymm_o = compute_c1h1_ ? ymm_f : ymm_c;
247251
ymm_t ymm_tmp = ymm_i;
252+
vmovups(ptr[reg_ptr_ct + offset], ymm_ct); // save ct
248253
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
249-
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * num_]);
254+
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * d]);
250255
act<ymm_t>(ymm_o, ymm_src, act_gate_);
251256
vmulps(ymm_o, ymm_tmp, ymm_o);
252-
// save ct and ht
253-
vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
254-
vmovups(ptr[reg_ptr_ht + offset], ymm_o);
255-
257+
vmovups(ptr[reg_ptr_ht + offset], ymm_o); // save ht
256258
offset += sizeof(float) * YMM_FLOAT_BLOCK;
257259
}
258260

259-
ret();
261+
if (use_peephole_) {
262+
postCode();
263+
} else {
264+
ret();
265+
}
260266
}
261267

268+
// Reports whether the GRU JIT path may be used for width `d`: AVX must be
// usable and `d` must be a multiple of 8.
bool GRUJitCode::init(int d) {
  if (!MayIUse(avx)) {
    return false;
  }
  return d % 8 == 0;
}
269+
270+
void GRUJitCode::generate() {
271+
reg64_t reg_ptr_gates = rax;
272+
reg64_t reg_ptr_ct_1 = r9;
273+
reg64_t reg_ptr_ct = r10;
274+
reg64_t reg_ptr_ht = r11;
275+
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
276+
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
277+
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
278+
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
279+
280+
ret();
281+
}
262282
} // namespace gen
263283
} // namespace jitkernel
264284
} // namespace math

paddle/fluid/operators/math/jit_code.h

Lines changed: 109 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,34 @@ class VActJitCode : public JitCode {
302302
pop(reg_ptr_global);
303303
}
304304

305+
template <typename JMM>
306+
void act(JMM& dst, JMM& src, operand_type type) { // NOLINT
307+
// use 15
308+
JMM zero = JMM(15);
309+
if (type_ == operand_type::relu) {
310+
vxorps(zero, zero, zero);
311+
}
312+
switch (type) {
313+
case operand_type::relu:
314+
relu_jmm<JMM>(dst, src, zero);
315+
break;
316+
case operand_type::exp:
317+
exp_jmm<JMM>(dst, src, 2, 3, 4, 5);
318+
break;
319+
case operand_type::sigmoid:
320+
sigmoid_jmm<JMM>(dst, src, 2, 3, 4, 5);
321+
break;
322+
case operand_type::tanh:
323+
tanh_jmm<JMM>(dst, src, 2, 3, 4, 5);
324+
break;
325+
case operand_type::identity:
326+
break;
327+
default:
328+
// throw error
329+
break;
330+
}
331+
}
332+
305333
protected:
306334
int num_;
307335
operand_type type_;
@@ -386,44 +414,94 @@ class LSTMJitCode : public VActJitCode {
386414
operand_type act_cand_;
387415
operand_type act_cell_;
388416
reg64_t param1{abi_param1};
389-
390417
xmm_t xmm_src = xmm_t(0);
391418
xmm_t xmm_c = xmm_t(1);
392-
xmm_t xmm_i = xmm_t(2);
393-
xmm_t xmm_f = xmm_t(3);
419+
xmm_t xmm_i = xmm_t(6);
420+
xmm_t xmm_f = xmm_t(7);
394421

395422
ymm_t ymm_src = ymm_t(0);
396-
ymm_t ymm_c = ymm_t(1);
397-
ymm_t ymm_i = ymm_t(2);
398-
ymm_t ymm_f = ymm_t(3);
423+
ymm_t ymm_c = ymm_t(1); // 2~5 for act
424+
ymm_t ymm_i = ymm_t(6);
425+
ymm_t ymm_f = ymm_t(7);
426+
};
399427

400-
template <typename JMM>
401-
void act(JMM& dst, JMM& src, operand_type type) { // NOLINT
402-
// use 15
403-
JMM zero = JMM(15);
404-
if (type_ == operand_type::relu) {
405-
vxorps(zero, zero, zero);
406-
}
407-
switch (type) {
408-
case operand_type::relu:
409-
relu_jmm<JMM>(dst, src, zero);
410-
break;
411-
case operand_type::exp:
412-
exp_jmm<JMM>(dst, src, 2, 3, 4, 5);
413-
break;
414-
case operand_type::sigmoid:
415-
sigmoid_jmm<JMM>(dst, src, 2, 3, 4, 5);
416-
break;
417-
case operand_type::tanh:
418-
tanh_jmm<JMM>(dst, src, 2, 3, 4, 5);
419-
break;
420-
case operand_type::identity:
421-
break;
422-
default:
423-
// throw error
424-
break;
428+
class GRUJitCode : public VActJitCode {
429+
public:
430+
const char* name() const override {
431+
std::string base = "GRUJitCode";
432+
if (id_ == 0) {
433+
base += "_H1";
434+
} else if (id_ == 1) {
435+
base += "_HtPart1";
436+
} else if (id_ == 2) {
437+
base += "_HtPart2";
425438
}
439+
auto AddTypeStr = [&](operand_type type) {
440+
switch (type) {
441+
case operand_type::relu:
442+
base += "_Relu";
443+
break;
444+
case operand_type::exp:
445+
base += "_Exp";
446+
break;
447+
case operand_type::sigmoid:
448+
base += "_Sigmoid";
449+
break;
450+
case operand_type::tanh:
451+
base += "_Tanh";
452+
break;
453+
case operand_type::identity:
454+
base += "_Identity";
455+
break;
456+
default:
457+
break;
458+
}
459+
};
460+
AddTypeStr(act_gate_);
461+
AddTypeStr(act_cand_);
462+
return base.c_str();
426463
}
464+
465+
explicit GRUJitCode(int id, const gru_attr_t& attr,
466+
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
467+
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
468+
code_ptr),
469+
id_(id) {
470+
auto typeExchange = [](const std::string& type) -> gen::operand_type {
471+
if (type == "sigmoid") {
472+
return operand_type::sigmoid;
473+
} else if (type == "relu") {
474+
return operand_type::relu;
475+
} else if (type == "tanh") {
476+
return operand_type::tanh;
477+
} else if (type == "identity" || type == "") {
478+
return operand_type::identity;
479+
} // else throw error
480+
return operand_type::identity;
481+
};
482+
num_ = attr.d;
483+
act_gate_ = typeExchange(attr.act_gate);
484+
act_cand_ = typeExchange(attr.act_cand);
485+
}
486+
static bool init(int d);
487+
void generate() override;
488+
489+
protected:
490+
int id_;
491+
int num_;
492+
operand_type act_gate_;
493+
operand_type act_cand_;
494+
reg64_t param1{abi_param1};
495+
496+
xmm_t xmm_src = xmm_t(0);
497+
xmm_t xmm_c = xmm_t(1);
498+
xmm_t xmm_i = xmm_t(6);
499+
xmm_t xmm_f = xmm_t(7);
500+
501+
ymm_t ymm_src = ymm_t(0);
502+
ymm_t ymm_c = ymm_t(1);
503+
ymm_t ymm_i = ymm_t(6);
504+
ymm_t ymm_f = ymm_t(7);
427505
};
428506

429507
#ifdef PADDLE_WITH_MKLDNN

paddle/fluid/operators/math/jit_kernel_rnn.cc

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class LSTMKernelImpl : public LSTMKernel<T> {
4040
explicit LSTMKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() {
4141
#ifdef PADDLE_WITH_XBYAK
4242
if (useJIT(attr.d)) {
43-
size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8; // should change
43+
size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8;
4444
jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096));
4545
this->ComputeCtHt =
4646
jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
@@ -66,7 +66,7 @@ class LSTMKernelImpl : public LSTMKernel<T> {
6666
#ifdef PADDLE_WITH_XBYAK
6767
template <>
6868
bool LSTMKernelImpl<float>::useJIT(int d) {
69-
return false; // not ready yet gen::LSTMJitCode::init(d);
69+
return gen::LSTMJitCode::init(d);
7070
}
7171
#endif
7272

@@ -82,7 +82,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
8282
explicit PeepholeKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() {
8383
#ifdef PADDLE_WITH_XBYAK
8484
if (useJIT(attr.d)) {
85-
size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8; // should change
85+
size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 4 * 8;
8686
jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096));
8787
this->ComputeCtHt =
8888
jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
@@ -175,12 +175,42 @@ class GRUKernelImpl : public GRUKernel<T> {
175175
static inline bool useJIT(int d) { return false; }
176176
static inline bool useMKL(int d) { return false; }
177177
/**
 * Selects the implementations for ComputeH1 / ComputeHtPart1 / ComputeHtPart2.
 *
 * When built with Xbyak and useJIT(attr.d) holds, one GRUJitCode generator is
 * created per kernel (ids 0, 1, 2) and each function pointer is taken from
 * its own generator. Otherwise the refer:: implementations are used.
 */
explicit GRUKernelImpl(const gru_attr_t& attr) : GRUKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
  if (useJIT(attr.d)) {
    size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8;  // should change
    // Never request a code buffer smaller than 4096 bytes.
    const size_t code_size = sz > 4096 ? sz : 4096;

    jitcode0_.reset(new gen::GRUJitCode(0, attr, code_size));
    this->ComputeH1 =
        jitcode0_->getCode<void (*)(gru_t*, const gru_attr_t*)>();

    jitcode1_.reset(new gen::GRUJitCode(1, attr, code_size));
    this->ComputeHtPart1 =
        jitcode1_->getCode<void (*)(gru_t*, const gru_attr_t*)>();

    jitcode2_.reset(new gen::GRUJitCode(2, attr, code_size));
    // Take the pointer from jitcode2_ (id 2). It was previously read from
    // jitcode1_, which made ComputeHtPart2 an alias of ComputeHtPart1.
    this->ComputeHtPart2 =
        jitcode2_->getCode<void (*)(gru_t*, const gru_attr_t*)>();
    return;
  }
#endif
  this->ComputeH1 = refer::GRUH1<T>;
  this->ComputeHtPart1 = refer::GRUHtPart1<T>;
  this->ComputeHtPart2 = refer::GRUHtPart2<T>;
}
199+
#ifdef PADDLE_WITH_XBYAK
200+
201+
private:
202+
std::unique_ptr<gen::GRUJitCode> jitcode0_{nullptr}, jitcode1_{nullptr},
203+
jitcode2_{nullptr};
204+
#endif
182205
};
183206

207+
#ifdef PADDLE_WITH_XBYAK
208+
template <>
209+
bool GRUKernelImpl<float>::useJIT(int d) {
210+
return false; // jitcode not ready yet
211+
}
212+
#endif
213+
184214
#define JITKERNEL_DEFINE_NAME_GRU(ker_key, ker_class) \
185215
template <> \
186216
std::string ker_class##Impl<float>::name(const gru_attr_t& attr) { \

0 commit comments

Comments
 (0)