@@ -229,6 +229,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
229
229
auto * xx = ctx.Output <LoDTensor>(" XX" );
230
230
auto * hidden_out = ctx.Output <LoDTensor>(" Hidden" );
231
231
auto * cell_out = ctx.Output <LoDTensor>(" Cell" );
232
+ bool is_reverse = ctx.Attr <bool >(" is_reverse" );
232
233
233
234
std::function<void (const int , const T *, T *)> act_gate, act_cell, act_cand;
234
235
auto & act_gate_str = ctx.Attr <std::string>(" gate_activation" );
@@ -247,8 +248,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
247
248
}
248
249
249
250
auto x_lod = x->lod ();
250
- auto x_dims = x->dims (); // T x M
251
- auto wh_dims = wh->dims (); // D x 4D
251
+ auto x_dims = x->dims (); // T x M
252
+ auto wh_dims = wh->dims (); // D x 4D
253
+ const int total_T = x_dims[0 ];
252
254
const int N = x_lod[0 ].size () - 1 ; // batch size
253
255
const int M = x_dims[1 ]; // x frame size
254
256
const int D = wh_dims[0 ];
@@ -266,17 +268,34 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
266
268
T* cell_out_data = cell_out->mutable_data <T>(ctx.GetPlace ());
267
269
268
270
auto blas = math::GetBlas<DeviceContext, T>(ctx);
269
- math::FCCompute<DeviceContext, T>(blas, x_dims[ 0 ] , D4, M, x_data, wx_data,
271
+ math::FCCompute<DeviceContext, T>(blas, total_T , D4, M, x_data, wx_data,
270
272
xx_data, bias->data <T>());
273
+ int xx_offset = D4;
274
+ int gate_offset = D;
275
+ if (is_reverse) {
276
+ const int offset = (total_T - 1 ) * D;
277
+ xx_data = xx_data + offset * 4 ;
278
+ hidden_out_data = hidden_out_data + offset;
279
+ cell_out_data = cell_out_data + offset;
280
+ xx_offset = -D4;
281
+ gate_offset = -D;
282
+ }
283
+
284
+ auto move_step = [&]() {
285
+ xx_data = xx_data + xx_offset;
286
+ hidden_out_data = hidden_out_data + gate_offset;
287
+ cell_out_data = cell_out_data + gate_offset;
288
+ };
271
289
272
290
for (int i = 0 ; i < N; ++i) {
273
- int seq_len = x_lod[0 ][i + 1 ] - x_lod[0 ][i];
291
+ int bid = is_reverse ? N - 1 - i : i;
292
+ int seq_len = x_lod[0 ][bid + 1 ] - x_lod[0 ][bid];
274
293
const T* prev_cell_data = NULL ;
275
294
const T* prev_hidden_data = NULL ;
276
295
int tstart = 0 ;
277
296
if (h0_data) {
278
- prev_hidden_data = h0_data + i * D;
279
- prev_cell_data = c0_data + i * D;
297
+ prev_hidden_data = h0_data + bid * D;
298
+ prev_cell_data = c0_data + bid * D;
280
299
} else {
281
300
// W_ch, W_ih, W_fh, W_oh
282
301
act_gate (D3, xx_data + D, xx_data + D);
@@ -292,10 +311,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
292
311
prev_cell_data = cell_out_data;
293
312
tstart = 1 ;
294
313
295
- // move offset
296
- xx_data = xx_data + D4;
297
- hidden_out_data = hidden_out_data + D;
298
- cell_out_data = cell_out_data + D;
314
+ move_step ();
299
315
}
300
316
for (int step = tstart; step < seq_len; ++step) {
301
317
blas.GEMM (CblasNoTrans, CblasNoTrans, 1 , D4, D, static_cast <T>(1 ),
@@ -323,10 +339,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
323
339
prev_hidden_data = hidden_out_data;
324
340
prev_cell_data = cell_out_data;
325
341
326
- // move offset
327
- xx_data = xx_data + D4;
328
- hidden_out_data = hidden_out_data + D;
329
- cell_out_data = cell_out_data + D;
342
+ move_step ();
330
343
}
331
344
}
332
345
}
0 commit comments