Skip to content

Commit c2cfb03

Browse files
committed
add lstm jitcode
1 parent 8bc1c5d commit c2cfb03

File tree

4 files changed

+198
-17
lines changed

4 files changed

+198
-17
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/math/jit_code.h"
16+
#include <stddef.h> // offsetof
1617
#include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me
1718

1819
namespace paddle {
@@ -210,6 +211,54 @@ void VActJitCode::generate() {
210211
ret();
211212
}
212213

214+
bool LSTMJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; }
215+
216+
void LSTMJitCode::generate() {
217+
reg64_t reg_ptr_gates = rax;
218+
reg64_t reg_ptr_ct_1 = r9;
219+
reg64_t reg_ptr_ct = r10;
220+
reg64_t reg_ptr_ht = r11;
221+
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
222+
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
223+
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
224+
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
225+
226+
int offset = 0;
227+
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
228+
/* C_t = C_t-1 * fgated + cand_gated * igated*/
229+
// c
230+
vmovups(ymm_src, ptr[reg_ptr_gates + offset]);
231+
act<ymm_t>(ymm_c, ymm_src, act_cand_);
232+
// i
233+
vmovups(ymm_src, ptr[reg_ptr_gates + offset + num_]);
234+
act<ymm_t>(ymm_i, ymm_src, act_gate_);
235+
vmulps(ymm_c, ymm_c, ymm_i);
236+
if (first_) {
237+
// f
238+
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 2 * num_]);
239+
act<ymm_t>(ymm_f, ymm_src, act_gate_);
240+
vmovups(ymm_i, ptr[reg_ptr_ct_1 + offset]);
241+
vmulps(ymm_f, ymm_f, ymm_i);
242+
vaddps(ymm_f, ymm_f, ymm_c);
243+
}
244+
/* H_t = act_cell(C_t) * ogated */
245+
ymm_t ymm_ct = first_ ? ymm_c : ymm_f;
246+
ymm_t ymm_o = first_ ? ymm_f : ymm_c;
247+
ymm_t ymm_tmp = ymm_i;
248+
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
249+
vmovups(ymm_src, ptr[reg_ptr_gates + offset + 3 * num_]);
250+
act<ymm_t>(ymm_o, ymm_src, act_gate_);
251+
vmulps(ymm_o, ymm_tmp, ymm_o);
252+
// save ct and ht
253+
vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
254+
vmovups(ptr[reg_ptr_ht + offset], ymm_o);
255+
256+
offset += sizeof(float) * YMM_FLOAT_BLOCK;
257+
}
258+
259+
ret();
260+
}
261+
213262
} // namespace gen
214263
} // namespace jitkernel
215264
} // namespace math

paddle/fluid/operators/math/jit_code.h

Lines changed: 94 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License. */
1616

1717
#include <string>
1818
#include "paddle/fluid/operators/math/jit_gen.h"
19+
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
1920
#include "paddle/fluid/platform/cpu_info.h"
2021

2122
namespace paddle {
@@ -46,14 +47,6 @@ extern const float exp_float_consts[];
4647
extern const int exp_int_0x7f[];
4748
extern int g_tmp_mem[];
4849

49-
// TODO(TJ): move these to some proper place
50-
#define SIGMOID_THRESHOLD_MIN -40.0
51-
#define SIGMOID_THRESHOLD_MAX 13.0
52-
#define EXP_MAX_INPUT 40.0
53-
#define XMM_FLOAT_BLOCK 4
54-
#define YMM_FLOAT_BLOCK 8
55-
#define ZMM_FLOAT_BLOCK 16
56-
5750
#define ALIGN32 __attribute__((aligned(32)))
5851
#define EXP_HIG 88.3762626647949f
5952
#define EXP_LOW -88.3762626647949f
@@ -322,6 +315,99 @@ class VActJitCode : public JitCode {
322315
ymm_t ymm_dst = ymm_t(1);
323316
};
324317

318+
class LSTMJitCode : public VActJitCode {
319+
public:
320+
const char* name() const override {
321+
std::string base = "LSTMJitCode";
322+
auto AddTypeStr = [&](operand_type type) {
323+
switch (type) {
324+
case operand_type::relu:
325+
base += "_Relu";
326+
break;
327+
case operand_type::exp:
328+
base += "_Exp";
329+
break;
330+
case operand_type::sigmoid:
331+
base += "_Sigmoid";
332+
break;
333+
case operand_type::tanh:
334+
base += "_Tanh";
335+
break;
336+
case operand_type::identity:
337+
base += "_Identity";
338+
break;
339+
default:
340+
break;
341+
}
342+
};
343+
if (first_) {
344+
base += "_C1H1";
345+
}
346+
AddTypeStr(act_gate_);
347+
AddTypeStr(act_cand_);
348+
AddTypeStr(act_cell_);
349+
return base.c_str();
350+
}
351+
352+
explicit LSTMJitCode(int d, bool first, operand_type act_gate,
353+
operand_type act_cand, operand_type act_cell,
354+
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
355+
: VActJitCode(d, act_gate, code_size, code_ptr),
356+
num_(d),
357+
first_(first),
358+
act_gate_(act_gate),
359+
act_cand_(act_cand),
360+
act_cell_(act_cell) {}
361+
static bool init(int d);
362+
void generate() override;
363+
364+
protected:
365+
int num_;
366+
bool first_;
367+
operand_type act_gate_;
368+
operand_type act_cand_;
369+
operand_type act_cell_;
370+
reg64_t param1{abi_param1};
371+
372+
xmm_t xmm_src = xmm_t(0);
373+
xmm_t xmm_c = xmm_t(1);
374+
xmm_t xmm_i = xmm_t(2);
375+
xmm_t xmm_f = xmm_t(3);
376+
377+
ymm_t ymm_src = ymm_t(0);
378+
ymm_t ymm_c = ymm_t(1);
379+
ymm_t ymm_i = ymm_t(2);
380+
ymm_t ymm_f = ymm_t(3);
381+
382+
template <typename JMM>
383+
void act(JMM& dst, JMM& src, operand_type type) { // NOLINT
384+
// use 15
385+
JMM zero = JMM(15);
386+
if (type_ == operand_type::relu) {
387+
vxorps(zero, zero, zero);
388+
}
389+
switch (type) {
390+
case operand_type::relu:
391+
relu_jmm<JMM>(dst, src, zero);
392+
break;
393+
case operand_type::exp:
394+
exp_jmm<JMM>(dst, src, 2, 3, 4, 5);
395+
break;
396+
case operand_type::sigmoid:
397+
sigmoid_jmm<JMM>(dst, src, 2, 3, 4, 5);
398+
break;
399+
case operand_type::tanh:
400+
tanh_jmm<JMM>(dst, src, 2, 3, 4, 5);
401+
break;
402+
case operand_type::identity:
403+
break;
404+
default:
405+
// throw error
406+
break;
407+
}
408+
}
409+
};
410+
325411
} // namespace gen
326412
} // namespace jitkernel
327413
} // namespace math

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include <memory> // for shared_ptr
1818
#include <string>
1919
#include <unordered_map>
20+
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
2021
#include "paddle/fluid/platform/cpu_info.h"
2122
#include "paddle/fluid/platform/macros.h"
2223

@@ -26,14 +27,7 @@ namespace operators {
2627
namespace math {
2728
namespace jitkernel {
2829

29-
// TODO(TJ): move these to some proper place
30-
#define SIGMOID_THRESHOLD_MIN -40.0
31-
#define SIGMOID_THRESHOLD_MAX 13.0
32-
#define EXP_MAX_INPUT 40.0
33-
#define XMM_FLOAT_BLOCK 4
34-
#define YMM_FLOAT_BLOCK 8
35-
#define ZMM_FLOAT_BLOCK 16
36-
30+
// TODO(TJ): remove me
3731
typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;
3832

3933
class Kernel {
@@ -124,10 +118,13 @@ class LSTMKernel : public Kernel {
124118
const T *wp_data = nullptr,
125119
T *checked = nullptr) const = 0;
126120

127-
// compute c1 and h1 without c0 or h0
128121
virtual void ComputeC1H1(T *gates, T *ct, T *ht,
129122
/* below only used in peephole*/
130123
const T *wp_data = nullptr) const = 0;
124+
125+
// void (*ComputeCtHt)(lstm_t *);
126+
// // compute c1 and h1 without c0 or h0
127+
// void (*ComputeC1H1)(lstm_t *);
131128
};
132129

133130
template <typename T>
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <string>
17+
#include <type_traits>
18+
19+
namespace paddle {
20+
namespace operators {
21+
namespace math {
22+
namespace jitkernel {
23+
24+
#define SIGMOID_THRESHOLD_MIN -40.0
25+
#define SIGMOID_THRESHOLD_MAX 13.0
26+
#define EXP_MAX_INPUT 40.0
27+
#define XMM_FLOAT_BLOCK 4
28+
#define YMM_FLOAT_BLOCK 8
29+
#define ZMM_FLOAT_BLOCK 16
30+
31+
typedef struct {
32+
void* gates; // gates: W_ch, W_ih, W_fh, W_oh
33+
const void* ct_1;
34+
void* ct;
35+
void* ht;
36+
/* below only used in peephole*/
37+
const void* wp_data{nullptr};
38+
void* checked{nullptr};
39+
} lstm_t;
40+
41+
typedef struct {
42+
int d;
43+
std::string act_gate, act_cand, act_cell;
44+
} lstm_attr_t;
45+
46+
} // namespace jitkernel
47+
} // namespace math
48+
} // namespace operators
49+
} // namespace paddle

0 commit comments

Comments
 (0)