
Commit 3562051

add gru refer code and remove redundant avx code
test=develop
1 parent f913860 commit 3562051

6 files changed: +163 -428 lines changed


paddle/fluid/operators/fused/fusion_gru_op.cc

Lines changed: 41 additions & 26 deletions
@@ -183,24 +183,27 @@ class FusionGRUKernel : public framework::OpKernel<T> {
   const int total_T = x_dims[0];                                  \
   const int D3 = wh_dims[1]
 
-#define INIT_OTHER_DEFINES                                                  \
-  auto* h0 = ctx.Input<Tensor>("H0");                                       \
-  auto* wx = ctx.Input<Tensor>("WeightX");                                  \
-  auto* bias = ctx.Input<Tensor>("Bias");                                   \
-  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");                       \
-  bool is_reverse = ctx.Attr<bool>("is_reverse");                           \
-  const int M = x_dims[1];                                                  \
-  const int D = wh_dims[0];                                                 \
-  const int D2 = D * 2;                                                     \
-  const auto& ker = math::jitkernel::KernelPool::Instance()                 \
-                        .template Get<math::jitkernel::GRUKernel<T>,        \
-                                      const std::string&, const std::string&>( \
-                            ctx.Attr<std::string>("gate_activation"),       \
-                            ctx.Attr<std::string>("activation"), D);        \
-  const T* x_data = x->data<T>();                                           \
-  const T* wx_data = wx->data<T>();                                         \
-  const T* wh_data = wh->data<T>();                                         \
-  auto place = ctx.GetPlace();                                              \
+#define INIT_OTHER_DEFINES                                            \
+  auto* h0 = ctx.Input<Tensor>("H0");                                 \
+  auto* wx = ctx.Input<Tensor>("WeightX");                            \
+  auto* bias = ctx.Input<Tensor>("Bias");                             \
+  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");                 \
+  bool is_reverse = ctx.Attr<bool>("is_reverse");                     \
+  const int M = x_dims[1];                                            \
+  const int D = wh_dims[0];                                           \
+  const int D2 = D * 2;                                               \
+  const math::jitkernel::gru_attr_t attr(                             \
+      D, ctx.Attr<std::string>("gate_activation"),                    \
+      ctx.Attr<std::string>("activation"));                           \
+  math::jitkernel::gru_t one_step;                                    \
+  const auto& ker =                                                   \
+      math::jitkernel::KernelPool::Instance()                         \
+          .template Get<math::jitkernel::GRUKernel<T>,                \
+                        const math::jitkernel::gru_attr_t&>(attr);    \
+  const T* x_data = x->data<T>();                                     \
+  const T* wx_data = wx->data<T>();                                   \
+  const T* wh_data = wh->data<T>();                                   \
+  auto place = ctx.GetPlace();                                        \
   T* xx_data = xx->mutable_data<T>(place)
 
   void SeqCompute(const framework::ExecutionContext& ctx) const {
@@ -237,7 +240,9 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     if (h0_data) {
       prev_hidden_data = h0_data + bid * D;
     } else {
-      ker->ComputeH1(xx_data, hidden_out_data);
+      one_step.gates = xx_data;
+      one_step.ht = hidden_out_data;
+      ker->ComputeH1(&one_step, &attr);
       prev_hidden_data = hidden_out_data;
       tstart = 1;
       move_step();
@@ -247,12 +252,15 @@ class FusionGRUKernel : public framework::OpKernel<T> {
      blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
                prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
                D3);
-     ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data);
+     one_step.gates = xx_data;
+     one_step.ht_1 = prev_hidden_data;
+     one_step.ht = hidden_out_data;
+     ker->ComputeHtPart1(&one_step, &attr);
      // gemm rt * Ws
      blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
                hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
                xx_data + D2, D3);
-     ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data);
+     ker->ComputeHtPart2(&one_step, &attr);
      // save prev
      prev_hidden_data = hidden_out_data;
      move_step();
@@ -314,7 +322,9 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     T* cur_out_data = batched_out_data;
     // W: {W_update, W_reset; W_state}
     for (int i = 0; i < max_bs; ++i) {
-      ker->ComputeH1(cur_in_data, cur_out_data);
+      one_step.gates = cur_in_data;
+      one_step.ht = cur_out_data;
+      ker->ComputeH1(&one_step, &attr);
       // add offset
       cur_in_data += D3;
       cur_out_data += D;
@@ -339,8 +349,11 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       T* cur_out_data = batched_out_data;
       T* cur_prev_hidden_data = prev_hidden_data;
       for (int i = 0; i < cur_bs; ++i) {
-        ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data,
-                            cur_out_data);
+        one_step.gates = cur_batched_data;
+        one_step.ht_1 = cur_prev_hidden_data;
+        one_step.ht = cur_out_data;
+        ker->ComputeHtPart1(&one_step, &attr);
+
        cur_batched_data += D3;
        cur_prev_hidden_data += D;
        cur_out_data += D;
@@ -354,8 +367,10 @@ class FusionGRUKernel : public framework::OpKernel<T> {
 
      cur_prev_hidden_data = prev_hidden_data;
      for (int i = 0; i < cur_bs; ++i) {
-       ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data,
-                           cur_out_data);
+       one_step.gates = cur_batched_data;
+       one_step.ht_1 = cur_prev_hidden_data;
+       one_step.ht = cur_out_data;
+       ker->ComputeHtPart2(&one_step, &attr);
        cur_batched_data += D3;
        cur_prev_hidden_data += D;
        cur_out_data += D;
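
For context, the pattern introduced in this file packs the per-step pointers into a gru_t and the width/activations into a gru_attr_t before calling the kernel. A minimal sketch of that calling convention follows; only the struct, field, and kernel names come from this patch, while the surrounding variables (d, gate_act, cand_act, gates_ptr, prev_hidden, cur_hidden) are illustrative.

// Hedged sketch of the new calling convention, not code from the patch.
math::jitkernel::gru_attr_t attr(d, gate_act, cand_act);
math::jitkernel::gru_t one_step;
const auto& ker =
    math::jitkernel::KernelPool::Instance()
        .template Get<math::jitkernel::GRUKernel<T>,
                      const math::jitkernel::gru_attr_t&>(attr);
one_step.gates = gates_ptr;   // one step's {update, reset, state} gate buffer
one_step.ht_1 = prev_hidden;  // hidden state of the previous step
one_step.ht = cur_hidden;     // hidden state written by this step
ker->ComputeHtPart1(&one_step, &attr);
// ... GEMM of (r_t * h_{t-1}) with the state weight into gates + 2*d, then:
ker->ComputeHtPart2(&one_step, &attr);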

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 4 additions & 4 deletions
@@ -122,18 +122,18 @@ class VTanhKernel : public VActKernel<T> {};
 template <typename T>
 class LSTMKernel : public Kernel {
  public:
-  void (*ComputeCtHt)(lstm_t *, const lstm_attr_t *);
   // compute c1 and h1 without c0 or h0
   void (*ComputeC1H1)(lstm_t *, const lstm_attr_t *);
+  void (*ComputeCtHt)(lstm_t *, const lstm_attr_t *);
 };
 
 template <typename T>
 class GRUKernel : public Kernel {
  public:
   // compute h1 without h0
-  virtual void ComputeH1(T *gates, T *ht) const = 0;
-  virtual void ComputeHtPart1(T *gates, const T *ht_1, T *ht) const = 0;
-  virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0;
+  void (*ComputeH1)(gru_t *, const gru_attr_t *);
+  void (*ComputeHtPart1)(gru_t *, const gru_attr_t *);
+  void (*ComputeHtPart2)(gru_t *, const gru_attr_t *);
 };
 
 template <typename T>
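
With GRUKernel now exposing plain function pointers instead of virtual methods (matching the existing LSTMKernel layout), an implementation can be wired up by pointing the members at free functions with the new (gru_t*, const gru_attr_t*) signature, such as the reference functions added in jit_kernel_refer.h. A hedged sketch, assuming a hypothetical implementation class; the actual registration in the library may differ.

// Hypothetical class inside namespace paddle::operators::math::jitkernel.
template <typename T>
class GRUKernelRefer : public GRUKernel<T> {
 public:
  explicit GRUKernelRefer(const gru_attr_t& attr) {
    // Bind the function-pointer members to the reference implementations.
    this->ComputeH1 = refer::GRUH1<T>;
    this->ComputeHtPart1 = refer::GRUHtPart1<T>;
    this->ComputeHtPart2 = refer::GRUHtPart2<T>;
  }
};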

paddle/fluid/operators/math/jit_kernel_exp.cc

Lines changed: 0 additions & 152 deletions
@@ -25,10 +25,6 @@ limitations under the License. */
 #include "paddle/fluid/platform/dynload/mklml.h"
 #endif
 
-#ifdef __AVX__
-#include <immintrin.h>
-#endif
-
 namespace paddle {
 namespace operators {
 namespace math {
@@ -235,154 +231,6 @@ REGISTER_JITKERNEL(vexp, VExpKernel);
 REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel);
 REGISTER_JITKERNEL(vtanh, VTanhKernel);
 
-namespace detail {
-
-#ifdef __AVX__
-
-#define ALIGN32 __attribute__((aligned(32)))
-
-#define _PS256_CONST(Name, Val)                                      \
-  static const float _ps256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
-                                                 Val, Val, Val, Val}
-
-#define _PI256_CONST(Name, Val)                                    \
-  static const int _pi256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
-                                               Val, Val, Val, Val}
-
-_PI256_CONST(0x7f, 0x7f);
-_PS256_CONST(one, 1.f);
-_PS256_CONST(0p5, 0.5f);
-_PS256_CONST(exp_hi, 88.3762626647949f);
-_PS256_CONST(exp_lo, -88.3762626647949f);
-_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
-_PS256_CONST(cephes_exp_C1, 0.693359375);
-_PS256_CONST(cephes_exp_C2, -2.12194440e-4);
-_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
-_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
-_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
-_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
-_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
-_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);
-
-typedef union imm_xmm_union {
-  __m256i imm;
-  __m128i xmm[2];
-} imm_xmm_union;
-
-#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \
-  {                                         \
-    imm_xmm_union u ALIGN32;                \
-    u.imm = imm_;                           \
-    xmm0_ = u.xmm[0];                       \
-    xmm1_ = u.xmm[1];                       \
-  }
-
-#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \
-  {                                         \
-    imm_xmm_union u ALIGN32;                \
-    u.xmm[0] = xmm0_;                       \
-    u.xmm[1] = xmm1_;                       \
-    imm_ = u.imm;                           \
-  }
-
-#define AVX2_BITOP_USING_SSE2(fn)                           \
-  static inline __m256i avx2_mm256_##fn(__m256i x, int y) { \
-    /* use SSE2 to perform the bitop AVX2 */                \
-    __m128i x1, x2;                                         \
-    __m256i ret;                                            \
-    COPY_IMM_TO_XMM(x, x1, x2);                             \
-    x1 = _mm_##fn(x1, y);                                   \
-    x2 = _mm_##fn(x2, y);                                   \
-    COPY_XMM_TO_IMM(x1, x2, ret);                           \
-    return ret;                                             \
-  }
-
-#define AVX2_INTOP_USING_SSE2(fn)                                    \
-  static inline __m256i avx2_mm256_add_epi32(__m256i x, __m256i y) { \
-    /* use SSE2 to perform the AVX2 integer operation */             \
-    __m128i x1, x2;                                                  \
-    __m128i y1, y2;                                                  \
-    __m256i ret;                                                     \
-    COPY_IMM_TO_XMM(x, x1, x2);                                      \
-    COPY_IMM_TO_XMM(y, y1, y2);                                      \
-    x1 = _mm_##fn(x1, y1);                                           \
-    x2 = _mm_##fn(x2, y2);                                           \
-    COPY_XMM_TO_IMM(x1, x2, ret);                                    \
-    return ret;                                                      \
-  }
-
-AVX2_BITOP_USING_SSE2(slli_epi32);
-AVX2_INTOP_USING_SSE2(add_epi32);
-
-#define AVXEXP_BASE                                                           \
-  __m256 tmp = _mm256_setzero_ps(), fx;                                       \
-  __m256 one = *reinterpret_cast<const __m256*>(_ps256_one);                  \
-  __m256i imm0;                                                               \
-  x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi));      \
-  x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo));      \
-  /* express exp(x) as exp(g + n*log(2)) */                                   \
-  fx = _mm256_mul_ps(x,                                                       \
-                     *reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF)); \
-  fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5));       \
-  tmp = _mm256_floor_ps(fx);                                                  \
-  /* if greater, substract 1 */                                               \
-  __m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);                           \
-  mask = _mm256_and_ps(mask, one);                                            \
-  fx = _mm256_sub_ps(tmp, mask);                                              \
-  tmp = _mm256_mul_ps(fx,                                                     \
-                      *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1)); \
-  __m256 z = _mm256_mul_ps(                                                   \
-      fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2));            \
-  x = _mm256_sub_ps(x, tmp);                                                  \
-  x = _mm256_sub_ps(x, z);                                                    \
-  z = _mm256_mul_ps(x, x);                                                    \
-  __m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0);          \
-  y = _mm256_mul_ps(y, x);                                                    \
-  y = _mm256_add_ps(y,                                                        \
-                    *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1));  \
-  y = _mm256_mul_ps(y, x);                                                    \
-  y = _mm256_add_ps(y,                                                        \
-                    *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2));  \
-  y = _mm256_mul_ps(y, x);                                                    \
-  y = _mm256_add_ps(y,                                                        \
-                    *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3));  \
-  y = _mm256_mul_ps(y, x);                                                    \
-  y = _mm256_add_ps(y,                                                        \
-                    *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4));  \
-  y = _mm256_mul_ps(y, x);                                                    \
-  y = _mm256_add_ps(y,                                                        \
-                    *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5));  \
-  y = _mm256_mul_ps(y, z);                                                    \
-  y = _mm256_add_ps(y, x);                                                    \
-  y = _mm256_add_ps(y, one);                                                  \
-  /* build 2^n */                                                             \
-  imm0 = _mm256_cvttps_epi32(fx)
-
-__m256 ExpAVX(__m256 x) {
-  AVXEXP_BASE;
-  // two AVX2 instructions using SSE2
-  imm0 = avx2_mm256_add_epi32(imm0,
-                              *reinterpret_cast<const __m256i*>(_pi256_0x7f));
-  imm0 = avx2_mm256_slli_epi32(imm0, 23);
-  __m256 pow2n = _mm256_castsi256_ps(imm0);
-  y = _mm256_mul_ps(y, pow2n);
-  return y;
-}
-#endif
-
-#ifdef __AVX2__
-__m256 ExpAVX2(__m256 x) {
-  AVXEXP_BASE;
-  // two AVX2 instructions
-  imm0 = _mm256_add_epi32(imm0, *reinterpret_cast<const __m256i*>(_pi256_0x7f));
-  imm0 = _mm256_slli_epi32(imm0, 23);
-  __m256 pow2n = _mm256_castsi256_ps(imm0);
-  y = _mm256_mul_ps(y, pow2n);
-  return y;
-}
-#endif
-
-}  // namespace detail
 }  // namespace jitkernel
 }  // namespace math
 }  // namespace operators
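
For reference, the deleted detail::ExpAVX and detail::ExpAVX2 routines implemented the classic Cephes-style range-reduction approximation of the exponential (the _ps256_cephes_* constants above are its coefficients). Roughly, as a paraphrase of the removed code rather than text from the patch:

$$ e^x = 2^n \cdot e^g, \qquad n = \big\lfloor x \log_2 e + 0.5 \big\rfloor, \qquad g = x - nC_1 - nC_2, \quad C_1 + C_2 \approx \ln 2 $$

$$ e^g \approx 1 + g + g^2\left(p_5 + p_4 g + p_3 g^2 + p_2 g^3 + p_1 g^4 + p_0 g^5\right) $$

where x is first clamped to [exp_lo, exp_hi] and the factor 2^n is built by writing n + 127 into the exponent bits of an IEEE-754 float (the slli_epi32(..., 23) step).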

paddle/fluid/operators/math/jit_kernel_impl.h

Lines changed: 22 additions & 8 deletions
@@ -38,20 +38,34 @@ typedef struct {
   void* checked{nullptr};
 } lstm_t;
 
-typedef struct lstm_attr_s {
-  bool use_peephole;
+typedef struct {
+  void* gates;  // gates: {W_update, W_reset; W_state}
+  const void* ht_1;
+  void* ht;
+} gru_t;
+
+struct rnn_attr_s {
   int d;
-  std::string act_gate, act_cand, act_cell;
+  std::string act_gate, act_cand;
+  rnn_attr_s() = default;
+  rnn_attr_s(int _d, const std::string& _act_gate, const std::string& _act_cand)
+      : d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
+};
+
+struct lstm_attr_s : public rnn_attr_s {
+  bool use_peephole;
+  std::string act_cell;
   lstm_attr_s() = default;
   lstm_attr_s(int _d, const std::string& _act_gate,
               const std::string& _act_cand, const std::string& _act_cell,
               bool _use_peephole = false)
-      : use_peephole(_use_peephole),
-        d(_d),
-        act_gate(_act_gate),
-        act_cand(_act_cand),
+      : rnn_attr_s(_d, _act_gate, _act_cand),
+        use_peephole(_use_peephole),
        act_cell(_act_cell) {}
-} lstm_attr_t;
+};
+
+typedef struct rnn_attr_s gru_attr_t;
+typedef struct lstm_attr_s lstm_attr_t;
 
 }  // namespace jitkernel
 }  // namespace math
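
The attribute structs now share a common rnn_attr_s base: gru_attr_t is simply an alias for it, while lstm_attr_t adds the cell activation and the peephole flag. A minimal usage sketch under the new layout; the width and activation names below are illustrative values, not taken from the patch.

#include <string>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"

namespace jit = paddle::operators::math::jitkernel;

void BuildAttrExample() {
  // GRU attributes: width, gate activation, candidate activation.
  jit::gru_attr_t gru_attr(64, "sigmoid", "tanh");
  // LSTM attributes additionally carry the cell activation and peephole flag.
  jit::lstm_attr_t lstm_attr(64, "sigmoid", "tanh", "tanh",
                             /*use_peephole=*/false);
  (void)gru_attr;
  (void)lstm_attr;
}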

paddle/fluid/operators/math/jit_kernel_refer.h

Lines changed: 40 additions & 0 deletions
@@ -185,6 +185,46 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
   VMul(gates + d2, gates + d3, ht, d);
 }
 
+// compute h1 without h0
+template <typename T>
+void GRUH1(gru_t* step, const gru_attr_t* attr) {
+  T* gates = reinterpret_cast<T*>(step->gates);
+  T* ht = reinterpret_cast<T*>(step->ht);
+  auto act_gate = getActFunc<T>(attr->act_gate);
+  auto act_cand = getActFunc<T>(attr->act_cand);
+  int d = attr->d;
+  int d2 = d * 2;
+  act_gate(gates, gates, d);
+  act_cand(gates + d2, gates + d2, d);
+  VMul(gates, gates + d2, ht, d);
+}
+
+template <typename T>
+void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
+  // W: {W_update, W_reset; W_state}
+  T* gates = reinterpret_cast<T*>(step->gates);
+  T* ht = reinterpret_cast<T*>(step->ht);
+  const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
+  auto act_gate = getActFunc<T>(attr->act_gate);
+  act_gate(gates, gates, attr->d * 2);
+  VMul(ht_1, gates + attr->d, ht, attr->d);
+}
+
+template <typename T>
+void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
+  T* gates = reinterpret_cast<T*>(step->gates);
+  T* ht = reinterpret_cast<T*>(step->ht);
+  const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
+  auto act_cand = getActFunc<T>(attr->act_cand);
+  int d = attr->d;
+  T* y = gates + d * 2;
+  act_cand(y, y, d);
+  // out = zt*ht~ + (1-zt)*ht_1
+  for (int i = 0; i < d; ++i) {
+    ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
+  }
+}
+
 }  // namespace refer
 }  // namespace jitkernel
 }  // namespace math
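
Written out, the three reference functions split one GRU step into the standard equations, with sigma_g the gate activation and sigma_c the candidate activation configured in gru_attr_t:

$$ u_t = \sigma_g(g_u), \qquad r_t = \sigma_g(g_r), \qquad \tilde h_t = \sigma_c\big(g_c + W_{state}\,(r_t \odot h_{t-1})\big) $$

$$ h_t = u_t \odot \tilde h_t + (1 - u_t) \odot h_{t-1} $$

GRUHtPart1 applies sigma_g to the update and reset slices and writes r_t ⊙ h_{t-1} into ht (the caller then runs the W_state GEMM into the candidate slice g_c); GRUHtPart2 applies sigma_c and performs the final blend; GRUH1 handles the first step without h_0, where the update reduces to h_1 = u_1 ⊙ \tilde h_1.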
