
Commit 30e47bc

Merge branch 'develop' into revert_vlog
2 parents: be04d99 + 3ae6692

File tree

15 files changed: +1100 −948 lines


Dockerfile

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,8 @@ RUN wget -q https://www.python.org/ftp/python/3.7.0/Python-3.7.0.tgz && \
     CFLAGS="-Wformat" ./configure --prefix=/usr/local/ --enable-shared > /dev/null && \
     make -j8 > /dev/null && make altinstall > /dev/null
 
+RUN rm -r /root/python_build
+
 RUN apt-get update && \
     apt-get install -y --allow-downgrades patchelf \
         python3 python3-dev python3-pip \
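A note on what this cleanup actually does: assuming /root/python_build was populated by the build RUN above, deleting it in a separate RUN instruction removes the files from the final filesystem but adds a new layer; the earlier layer that wrote those files still contributes to image size. That is general Docker layering behavior, not something specific to this image.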

paddle/fluid/operators/fused/fusion_gru_op.cc

Lines changed: 41 additions & 26 deletions
@@ -183,24 +183,27 @@ class FusionGRUKernel : public framework::OpKernel<T> {
   const int total_T = x_dims[0]; \
   const int D3 = wh_dims[1]
 
-#define INIT_OTHER_DEFINES \
-  auto* h0 = ctx.Input<Tensor>("H0"); \
-  auto* wx = ctx.Input<Tensor>("WeightX"); \
-  auto* bias = ctx.Input<Tensor>("Bias"); \
-  auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
-  bool is_reverse = ctx.Attr<bool>("is_reverse"); \
-  const int M = x_dims[1]; \
-  const int D = wh_dims[0]; \
-  const int D2 = D * 2; \
-  const auto& ker = math::jitkernel::KernelPool::Instance() \
-                        .template Get<math::jitkernel::GRUKernel<T>, \
-                                      const std::string&, const std::string&>( \
-                            ctx.Attr<std::string>("gate_activation"), \
-                            ctx.Attr<std::string>("activation"), D); \
-  const T* x_data = x->data<T>(); \
-  const T* wx_data = wx->data<T>(); \
-  const T* wh_data = wh->data<T>(); \
-  auto place = ctx.GetPlace(); \
+#define INIT_OTHER_DEFINES \
+  auto* h0 = ctx.Input<Tensor>("H0"); \
+  auto* wx = ctx.Input<Tensor>("WeightX"); \
+  auto* bias = ctx.Input<Tensor>("Bias"); \
+  auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
+  bool is_reverse = ctx.Attr<bool>("is_reverse"); \
+  const int M = x_dims[1]; \
+  const int D = wh_dims[0]; \
+  const int D2 = D * 2; \
+  const math::jitkernel::gru_attr_t attr( \
+      D, ctx.Attr<std::string>("gate_activation"), \
+      ctx.Attr<std::string>("activation")); \
+  math::jitkernel::gru_t one_step; \
+  const auto& ker = \
+      math::jitkernel::KernelPool::Instance() \
+          .template Get<math::jitkernel::GRUKernel<T>, \
+                        const math::jitkernel::gru_attr_t&>(attr); \
+  const T* x_data = x->data<T>(); \
+  const T* wx_data = wx->data<T>(); \
+  const T* wh_data = wh->data<T>(); \
+  auto place = ctx.GetPlace(); \
   T* xx_data = xx->mutable_data<T>(place)
 
   void SeqCompute(const framework::ExecutionContext& ctx) const {
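For readers following the API change: the GRU kernel's positional pointer arguments are folded into two small structs, an attribute pack used for the kernel-pool lookup and a per-step argument pack. A minimal sketch of what this hunk implies they look like; the field and constructor names are taken from the call sites in this diff, while the exact types and layout in math/jitkernel are assumptions:

// Sketch only -- inferred from the call sites above, not the actual
// paddle/fluid/operators/math/jitkernel definitions.
#include <string>

namespace math {
namespace jitkernel {

struct gru_attr_t {
  int d;                 // hidden size D
  std::string act_gate;  // "gate_activation" op attribute
  std::string act_cand;  // "activation" op attribute
  gru_attr_t(int dim, const std::string& gate, const std::string& cand)
      : d(dim), act_gate(gate), act_cand(cand) {}
};

struct gru_t {
  void* gates{nullptr};       // packed gate pre-activations (xx_data)
  const void* ht_1{nullptr};  // previous hidden state h_{t-1}
  void* ht{nullptr};          // output hidden state h_t
};

}  // namespace jitkernel
}  // namespace math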
@@ -237,7 +240,9 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       if (h0_data) {
         prev_hidden_data = h0_data + bid * D;
       } else {
-        ker->ComputeH1(xx_data, hidden_out_data);
+        one_step.gates = xx_data;
+        one_step.ht = hidden_out_data;
+        ker->ComputeH1(&one_step, &attr);
         prev_hidden_data = hidden_out_data;
         tstart = 1;
         move_step();
@@ -247,12 +252,15 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
                 prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
                 D3);
-      ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data);
+      one_step.gates = xx_data;
+      one_step.ht_1 = prev_hidden_data;
+      one_step.ht = hidden_out_data;
+      ker->ComputeHtPart1(&one_step, &attr);
       // gemm rt * Ws
       blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
                 hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
                 xx_data + D2, D3);
-      ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data);
+      ker->ComputeHtPart2(&one_step, &attr);
       // save prev
       prev_hidden_data = hidden_out_data;
       move_step();
@@ -314,7 +322,9 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     T* cur_out_data = batched_out_data;
     // W: {W_update, W_reset; W_state}
     for (int i = 0; i < max_bs; ++i) {
-      ker->ComputeH1(cur_in_data, cur_out_data);
+      one_step.gates = cur_in_data;
+      one_step.ht = cur_out_data;
+      ker->ComputeH1(&one_step, &attr);
       // add offset
       cur_in_data += D3;
       cur_out_data += D;
@@ -339,8 +349,11 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       T* cur_out_data = batched_out_data;
       T* cur_prev_hidden_data = prev_hidden_data;
       for (int i = 0; i < cur_bs; ++i) {
-        ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data,
-                            cur_out_data);
+        one_step.gates = cur_batched_data;
+        one_step.ht_1 = cur_prev_hidden_data;
+        one_step.ht = cur_out_data;
+        ker->ComputeHtPart1(&one_step, &attr);
+
         cur_batched_data += D3;
         cur_prev_hidden_data += D;
         cur_out_data += D;
@@ -354,8 +367,10 @@ class FusionGRUKernel : public framework::OpKernel<T> {
 
       cur_prev_hidden_data = prev_hidden_data;
       for (int i = 0; i < cur_bs; ++i) {
-        ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data,
-                            cur_out_data);
+        one_step.gates = cur_batched_data;
+        one_step.ht_1 = cur_prev_hidden_data;
+        one_step.ht = cur_out_data;
+        ker->ComputeHtPart2(&one_step, &attr);
         cur_batched_data += D3;
         cur_prev_hidden_data += D;
         cur_out_data += D;
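One detail worth noting in these hunks: in the sequential path, ComputeHtPart2 is invoked with the same one_step binding that was set up before ComputeHtPart1, since gates, ht_1, and ht do not change between the two calls. The batched path re-assigns all three fields in its second loop because the cursor pointers there are rewound and advanced independently.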

paddle/fluid/operators/fused/fusion_lstm_op.cc

Lines changed: 46 additions & 27 deletions
@@ -236,27 +236,31 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
   const int D = wh_dims[0]; \
   const int D4 = wh_dims[1]
 
-#define INIT_OTHER_DEFINES \
-  const T* x_data = x->data<T>(); \
-  const T* wx_data = wx->data<T>(); \
-  const T* wh_data = wh->data<T>(); \
-  /* diagonal weight*/ \
-  const T* wp_data = bias->data<T>() + D4; \
-  /* for peephole only*/ \
-  T* checked_cell_data = nullptr; \
-  auto place = ctx.GetPlace(); \
-  if (use_peepholes) { \
-    /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
-    auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
-    checked_cell_data = checked_cell->mutable_data<T>(place); \
-  } \
-  const auto& ker = \
-      math::jitkernel::KernelPool::Instance() \
-          .template Get<math::jitkernel::LSTMKernel<T>, const std::string&, \
-                        const std::string&, const std::string&>( \
-              ctx.Attr<std::string>("gate_activation"), \
-              ctx.Attr<std::string>("candidate_activation"), \
-              ctx.Attr<std::string>("cell_activation"), D, use_peepholes)
+#define INIT_OTHER_DEFINES \
+  const T* x_data = x->data<T>(); \
+  const T* wx_data = wx->data<T>(); \
+  const T* wh_data = wh->data<T>(); \
+  /* diagonal weight*/ \
+  const T* wp_data = bias->data<T>() + D4; \
+  /* for peephole only*/ \
+  T* checked_cell_data = nullptr; \
+  auto place = ctx.GetPlace(); \
+  if (use_peepholes) { \
+    /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
+    auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
+    checked_cell_data = checked_cell->mutable_data<T>(place); \
+  } \
+  const math::jitkernel::lstm_attr_t attr( \
+      D, ctx.Attr<std::string>("gate_activation"), \
+      ctx.Attr<std::string>("candidate_activation"), \
+      ctx.Attr<std::string>("cell_activation"), use_peepholes); \
+  math::jitkernel::lstm_t one_step; \
+  one_step.wp = wp_data; \
+  one_step.checked = checked_cell_data; \
+  const auto& ker = \
+      math::jitkernel::KernelPool::Instance() \
+          .template Get<math::jitkernel::LSTMKernel<T>, \
+                        const math::jitkernel::lstm_attr_t&>(attr)
 
 // Wh GEMM
 #define GEMM_WH_ADDON(bs, prev, out) \
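The LSTM side gets the same treatment, with two extra per-step fields for the peephole path. A minimal sketch under the same caveat as the GRU one: field and constructor names come from the call sites in this diff, and the exact math/jitkernel definitions are assumptions:

// Sketch only -- inferred from the call sites in this file.
#include <string>

namespace math {
namespace jitkernel {

struct lstm_attr_t {
  int d;                 // hidden size D
  std::string act_gate;  // "gate_activation" op attribute
  std::string act_cand;  // "candidate_activation" op attribute
  std::string act_cell;  // "cell_activation" op attribute
  bool use_peephole;
  lstm_attr_t(int dim, const std::string& gate, const std::string& cand,
              const std::string& cell, bool peephole)
      : d(dim),
        act_gate(gate),
        act_cand(cand),
        act_cell(cell),
        use_peephole(peephole) {}
};

struct lstm_t {
  void* gates{nullptr};       // packed gate pre-activations (xx_data)
  const void* ct_1{nullptr};  // previous cell state c_{t-1}
  void* ct{nullptr};          // output cell state c_t
  void* ht{nullptr};          // output hidden state h_t
  const void* wp{nullptr};    // diagonal peephole weights (wp_data)
  void* checked{nullptr};     // peephole scratch (checked_cell_data)
};

}  // namespace jitkernel
}  // namespace math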
@@ -299,7 +303,10 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
         prev_h_data = h0_data + bid * D;
         prev_c_data = c0_data + bid * D;
       } else {
-        ker->ComputeC1H1(xx_data, c_out_data, h_out_data, wp_data);
+        one_step.gates = xx_data;
+        one_step.ct = c_out_data;
+        one_step.ht = h_out_data;
+        ker->ComputeC1H1(&one_step, &attr);
         tstart = 1;
         // move one step
         prev_h_data = h_out_data;
@@ -310,8 +317,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       }
       for (int step = tstart; step < seq_len; ++step) {
         GEMM_WH_ADDON(1, prev_h_data, xx_data);
-        ker->ComputeCtHt(xx_data, prev_c_data, c_out_data, h_out_data, wp_data,
-                         checked_cell_data);
+
+        one_step.gates = xx_data;
+        one_step.ct_1 = prev_c_data;
+        one_step.ct = c_out_data;
+        one_step.ht = h_out_data;
+        ker->ComputeCtHt(&one_step, &attr);
         // move one step
         prev_h_data = h_out_data;
         prev_c_data = c_out_data;
@@ -388,7 +399,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     T* cur_h_out_data = batched_h_out_data;
     T* cur_c_out_data = batched_c_out_data;
     for (int i = 0; i < max_bs; ++i) {
-      ker->ComputeC1H1(cur_in_data, cur_c_out_data, cur_h_out_data, wp_data);
+      one_step.gates = cur_in_data;
+      one_step.ct = cur_c_out_data;
+      one_step.ht = cur_h_out_data;
+      ker->ComputeC1H1(&one_step, &attr);
+
       cur_in_data += D4;
       cur_c_out_data += D;
       cur_h_out_data += D;
@@ -413,8 +428,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       T* cur_c_out_data = batched_c_out_data;
       T* cur_h_out_data = batched_h_out_data;
      for (int i = 0; i < cur_bs; ++i) {
-        ker->ComputeCtHt(cur_in_data, cur_prev_c_data, cur_c_out_data,
-                         cur_h_out_data, wp_data, checked_cell_data);
+        one_step.gates = cur_in_data;
+        one_step.ct_1 = cur_prev_c_data;
+        one_step.ct = cur_c_out_data;
+        one_step.ht = cur_h_out_data;
+        ker->ComputeCtHt(&one_step, &attr);
+
         // move one batch
         cur_in_data += D4;
         cur_prev_c_data += D;
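A small design point implied by the macro above: one_step.wp and one_step.checked are bound once in INIT_OTHER_DEFINES, so the per-step and per-batch loops only reassign the pointers that actually advance (gates, ct_1, ct, ht), while the peephole weights and scratch buffer ride along unchanged in the struct.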
