@@ -285,18 +285,23 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
285
285
act_cell (D, ct, gates + D2); \
286
286
blas.VMUL(D, gates + D2, gates + D3, ht)
287
287
288
- #define COMPUTE_CtHt_WITHOUT_H0C0 (gates, ct, ht ) \
289
- act_gate (D, gates + D, gates + D); \
290
- act_cand (D, gates, gates); \
291
- /* C_t = igated * cgated*/ \
292
- blas.VMUL(D, gates, gates + D, ct); \
293
- /* get outgated*/ \
294
- if (use_peepholes) { \
295
- /* put W_oc * C_t on igated */ \
296
- blas.VMUL (D, wc_data + D2, ct, gates + D); \
297
- blas.VADD (D, gates + D, gates + D3, gates + D3); \
298
- } \
299
- act_gate (D, gates + D3, gates + D3); \
288
+ #define GET_Ct_NOH0C0 (gates, ct ) \
289
+ /* C_t = igated * cgated: act_gate runs on gates + D and act_cand on gates (each with the same pointer passed twice), then the VMUL of the two is stored in ct. No previous cell state is read. */ \
290
+ act_gate (D, gates + D, gates + D); \
291
+ act_cand (D, gates, gates); \
292
+ blas.VMUL(D, gates, gates + D, ct)
293
+
294
+ #define COMPUTE_CtHt_NOH0C0 (gates, ct, ht ) /* C_t and H_t without a previous cell state and without the peephole term */ \
295
+ GET_Ct_NOH0C0 (gates, ct); \
296
+ act_gate (D, gates + D3, gates + D3); /* the gates + D3 slice is activated as-is; nothing is added to it first */ \
297
+ GET_Ht (ct, gates, ht)
298
+
299
+ #define COMPUTE_CtHt_PEEPHOLE_NOH0C0 (gates, ct, ht ) \
300
+ GET_Ct_NOH0C0 (gates, ct); \
301
+ /* get outgated: W_oc * C_t (wc_data + D2 times ct) overwrites the igated slice gates + D, is added into gates + D3, and gates + D3 is then activated */ \
302
+ blas.VMUL(D, wc_data + D2, ct, gates + D); \
303
+ blas.VADD(D, gates + D, gates + D3, gates + D3); \
304
+ act_gate (D, gates + D3, gates + D3); \
300
305
GET_Ht (ct, gates, ht)
301
306
302
307
#define COMPUTE_CtHt (gates, ct_1, ct, ht ) \
@@ -354,24 +359,38 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
354
359
h_out_data = h_out_data + gate_offset; \
355
360
c_out_data = c_out_data + gate_offset
356
361
357
- #define PROCESS_H0C0 \
358
- int bid = is_reverse ? N - 1 - i : i; \
359
- int seq_len = x_lod[0 ][bid + 1 ] - x_lod[0 ][bid]; \
360
- const T* prev_c_data = nullptr ; \
361
- const T* prev_h_data = nullptr ; \
362
- int tstart = 0 ; \
363
- if (h0_data) { \
364
- prev_h_data = h0_data + bid * D; \
365
- prev_c_data = c0_data + bid * D; \
366
- } else { \
367
- COMPUTE_CtHt_WITHOUT_H0C0 (xx_data, c_out_data, h_out_data); \
368
- MOVE_ONE_STEP; \
369
- tstart = 1 ; \
362
+ #define PROCESS_H0C0_DEFINES /* declares the per-sequence locals shared by both PROCESS_H0C0 variants */ \
363
+ int bid = is_reverse ? N - 1 - i : i; /* index N - 1 - i when is_reverse is set, otherwise i */ \
364
+ int seq_len = x_lod[0 ][bid + 1 ] - x_lod[0 ][bid]; /* difference of consecutive x_lod[0] entries */ \
365
+ const T* prev_c_data = nullptr ; \
366
+ const T* prev_h_data = nullptr ; \
367
+ int tstart = 0 /* starting value for the step loop that follows the macro */
368
+
369
+ #define PROCESS_H0C0_PEEPHOLE /* peephole variant: sets up prev_h_data / prev_c_data and tstart for sequence bid */ \
370
+ PROCESS_H0C0_DEFINES; \
371
+ if (h0_data) { /* NOTE(review): c0_data is dereferenced under this check alone - confirm it is always set together with h0_data */ \
372
+ prev_h_data = h0_data + bid * D; \
373
+ prev_c_data = c0_data + bid * D; \
374
+ } else { /* no h0_data: compute step 0 with the NOH0C0 macro, advance one step, start the loop at 1 */ \
375
+ COMPUTE_CtHt_PEEPHOLE_NOH0C0 (xx_data, c_out_data, h_out_data); \
376
+ MOVE_ONE_STEP; \
377
+ tstart = 1 ; \
378
+ }
379
+
380
+ #define PROCESS_H0C0 /* non-peephole variant: sets up prev_h_data / prev_c_data and tstart for sequence bid */ \
381
+ PROCESS_H0C0_DEFINES; \
382
+ if (h0_data) { /* NOTE(review): c0_data is dereferenced under this check alone - confirm it is always set together with h0_data */ \
383
+ prev_h_data = h0_data + bid * D; \
384
+ prev_c_data = c0_data + bid * D; \
385
+ } else { /* no h0_data: compute step 0 with the NOH0C0 macro, advance one step, start the loop at 1 */ \
386
+ COMPUTE_CtHt_NOH0C0 (xx_data, c_out_data, h_out_data); \
387
+ MOVE_ONE_STEP; \
388
+ tstart = 1 ; \
370
389
}
371
390
372
391
if (use_peepholes) {
373
392
for (int i = 0 ; i < N; ++i) {
374
- PROCESS_H0C0;
393
+ PROCESS_H0C0_PEEPHOLE
375
394
for (int step = tstart; step < seq_len; ++step) {
376
395
GEMM_WH_ADDON (1 , prev_h_data, xx_data);
377
396
COMPUTE_CtHt_PEEPHOLE (xx_data, prev_c_data, c_out_data, h_out_data);
@@ -380,14 +399,16 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
380
399
}
381
400
} else {
382
401
for (int i = 0 ; i < N; ++i) {
383
- PROCESS_H0C0;
402
+ PROCESS_H0C0
384
403
for (int step = tstart; step < seq_len; ++step) {
385
404
GEMM_WH_ADDON (1 , prev_h_data, xx_data);
386
405
COMPUTE_CtHt (xx_data, prev_c_data, c_out_data, h_out_data);
387
406
MOVE_ONE_STEP;
388
407
}
389
408
}
390
409
}
410
+ #undef PROCESS_H0C0_DEFINES
411
+ #undef PROCESS_H0C0_PEEPHOLE
391
412
#undef PROCESS_H0C0
392
413
#undef MOVE_ONE_STEP
393
414
}
@@ -460,7 +481,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
460
481
T* cur_h_out_data = batched_h_out_data;
461
482
T* cur_c_out_data = batched_c_out_data;
462
483
for (int i = 0 ; i < max_bs; ++i) {
463
- COMPUTE_CtHt_WITHOUT_H0C0 (cur_in_data, cur_c_out_data, cur_h_out_data);
484
+ GET_Ct_NOH0C0 (cur_in_data, cur_c_out_data);
485
+ if (use_peepholes) {
486
+ blas.VMUL (D, wc_data + D2, cur_c_out_data, cur_in_data + D);
487
+ blas.VADD (D, cur_in_data + D, cur_in_data + D3, cur_in_data + D3);
488
+ }
489
+ act_gate (D, cur_in_data + D3, cur_in_data + D3);
490
+ GET_Ht (cur_c_out_data, cur_in_data, cur_h_out_data);
464
491
cur_in_data += D4;
465
492
cur_c_out_data += D;
466
493
cur_h_out_data += D;
@@ -541,7 +568,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
541
568
542
569
#undef COMPUTE_CtHt_PEEPHOLE
543
570
#undef COMPUTE_CtHt
544
- #undef COMPUTE_CtHt_WITHOUT_H0C0
571
+ #undef GET_Ct_NOH0C0
572
+ #undef COMPUTE_CtHt_NOH0C0
573
+ #undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
545
574
#undef GET_Ht
546
575
#undef GET_Ct
547
576
#undef GEMM_WH_ADDON
0 commit comments