Skip to content

Commit d7ac1cc

Browse files
committed
refine seq when bs is large
1 parent 9dd5a17 commit d7ac1cc

File tree

2 files changed

+59
-30
lines changed

2 files changed

+59
-30
lines changed

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -285,18 +285,23 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
285285
act_cell(D, ct, gates + D2); \
286286
blas.VMUL(D, gates + D2, gates + D3, ht)
287287

288-
#define COMPUTE_CtHt_WITHOUT_H0C0(gates, ct, ht) \
289-
act_gate(D, gates + D, gates + D); \
290-
act_cand(D, gates, gates); \
291-
/* C_t = igated * cgated*/ \
292-
blas.VMUL(D, gates, gates + D, ct); \
293-
/* get outgated*/ \
294-
if (use_peepholes) { \
295-
/* put W_oc * C_t on igated */ \
296-
blas.VMUL(D, wc_data + D2, ct, gates + D); \
297-
blas.VADD(D, gates + D, gates + D3, gates + D3); \
298-
} \
299-
act_gate(D, gates + D3, gates + D3); \
288+
#define GET_Ct_NOH0C0(gates, ct) \
289+
/* C_t = igated * cgated */ \
290+
act_gate(D, gates + D, gates + D); \
291+
act_cand(D, gates, gates); \
292+
blas.VMUL(D, gates, gates + D, ct)
293+
294+
#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \
295+
GET_Ct_NOH0C0(gates, ct); \
296+
act_gate(D, gates + D3, gates + D3); \
297+
GET_Ht(ct, gates, ht)
298+
299+
#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \
300+
GET_Ct_NOH0C0(gates, ct); \
301+
/* get outgated: stash W_oc * C_t in the igated slot, then add it to the output gate */ \
302+
blas.VMUL(D, wc_data + D2, ct, gates + D); \
303+
blas.VADD(D, gates + D, gates + D3, gates + D3); \
304+
act_gate(D, gates + D3, gates + D3); \
300305
GET_Ht(ct, gates, ht)
301306

302307
#define COMPUTE_CtHt(gates, ct_1, ct, ht) \
@@ -354,24 +359,38 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
354359
h_out_data = h_out_data + gate_offset; \
355360
c_out_data = c_out_data + gate_offset
356361

357-
#define PROCESS_H0C0 \
358-
int bid = is_reverse ? N - 1 - i : i; \
359-
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \
360-
const T* prev_c_data = nullptr; \
361-
const T* prev_h_data = nullptr; \
362-
int tstart = 0; \
363-
if (h0_data) { \
364-
prev_h_data = h0_data + bid * D; \
365-
prev_c_data = c0_data + bid * D; \
366-
} else { \
367-
COMPUTE_CtHt_WITHOUT_H0C0(xx_data, c_out_data, h_out_data); \
368-
MOVE_ONE_STEP; \
369-
tstart = 1; \
362+
#define PROCESS_H0C0_DEFINES \
363+
int bid = is_reverse ? N - 1 - i : i; \
364+
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \
365+
const T* prev_c_data = nullptr; \
366+
const T* prev_h_data = nullptr; \
367+
int tstart = 0
368+
369+
#define PROCESS_H0C0_PEEPHOLE \
370+
PROCESS_H0C0_DEFINES; \
371+
if (h0_data) { \
372+
prev_h_data = h0_data + bid * D; \
373+
prev_c_data = c0_data + bid * D; \
374+
} else { \
375+
COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \
376+
MOVE_ONE_STEP; \
377+
tstart = 1; \
378+
}
379+
380+
#define PROCESS_H0C0 \
381+
PROCESS_H0C0_DEFINES; \
382+
if (h0_data) { \
383+
prev_h_data = h0_data + bid * D; \
384+
prev_c_data = c0_data + bid * D; \
385+
} else { \
386+
COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \
387+
MOVE_ONE_STEP; \
388+
tstart = 1; \
370389
}
371390

372391
if (use_peepholes) {
373392
for (int i = 0; i < N; ++i) {
374-
PROCESS_H0C0;
393+
PROCESS_H0C0_PEEPHOLE
375394
for (int step = tstart; step < seq_len; ++step) {
376395
GEMM_WH_ADDON(1, prev_h_data, xx_data);
377396
COMPUTE_CtHt_PEEPHOLE(xx_data, prev_c_data, c_out_data, h_out_data);
@@ -380,14 +399,16 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
380399
}
381400
} else {
382401
for (int i = 0; i < N; ++i) {
383-
PROCESS_H0C0;
402+
PROCESS_H0C0
384403
for (int step = tstart; step < seq_len; ++step) {
385404
GEMM_WH_ADDON(1, prev_h_data, xx_data);
386405
COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data);
387406
MOVE_ONE_STEP;
388407
}
389408
}
390409
}
410+
#undef PROCESS_H0C0_DEFINES
411+
#undef PROCESS_H0C0_PEEPHOLE
391412
#undef PROCESS_H0C0
392413
#undef MOVE_ONE_STEP
393414
}
@@ -460,7 +481,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
460481
T* cur_h_out_data = batched_h_out_data;
461482
T* cur_c_out_data = batched_c_out_data;
462483
for (int i = 0; i < max_bs; ++i) {
463-
COMPUTE_CtHt_WITHOUT_H0C0(cur_in_data, cur_c_out_data, cur_h_out_data);
484+
GET_Ct_NOH0C0(cur_in_data, cur_c_out_data);
485+
if (use_peepholes) {
486+
blas.VMUL(D, wc_data + D2, cur_c_out_data, cur_in_data + D);
487+
blas.VADD(D, cur_in_data + D, cur_in_data + D3, cur_in_data + D3);
488+
}
489+
act_gate(D, cur_in_data + D3, cur_in_data + D3);
490+
GET_Ht(cur_c_out_data, cur_in_data, cur_h_out_data);
464491
cur_in_data += D4;
465492
cur_c_out_data += D;
466493
cur_h_out_data += D;
@@ -541,7 +568,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
541568

542569
#undef COMPUTE_CtHt_PEEPHOLE
543570
#undef COMPUTE_CtHt
544-
#undef COMPUTE_CtHt_WITHOUT_H0C0
571+
#undef GET_Ct_NOH0C0
572+
#undef COMPUTE_CtHt_NOH0C0
573+
#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
545574
#undef GET_Ht
546575
#undef GET_Ct
547576
#undef GEMM_WH_ADDON

python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def set_conf(self):
183183
self.is_reverse = True
184184

185185

186-
class TestFusionLSTMOpPoopholesBS1(TestFusionLSTMOp):
186+
class TestFusionLSTMOpPeepholesBS1(TestFusionLSTMOp):
187187
def set_conf(self):
188188
self.use_peepholes = True
189189
self.lod = [[2]]

0 commit comments

Comments
 (0)