@@ -220,24 +220,105 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
   void SeqCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = paddle::platform::CPUDeviceContext;
     auto* x = ctx.Input<LoDTensor>("X");
+    auto* h0 = ctx.Input<Tensor>("H0");
+    auto* c0 = ctx.Input<Tensor>("C0");
     auto* wx = ctx.Input<Tensor>("WeightX");
     auto* wh = ctx.Input<Tensor>("WeightH");
     auto* bias = ctx.Input<Tensor>("Bias");

     auto* xx = ctx.Output<LoDTensor>("XX");
+    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
+    auto* cell_out = ctx.Output<LoDTensor>("Cell");

-    auto x_dims = x->dims();    // T x M
-    auto wh_dims = wh->dims();  // D x 4D
-    const int M = x_dims[1];    // x frame size
+    auto x_lod = x->lod();
+    auto x_dims = x->dims();    // T x M
+    auto wh_dims = wh->dims();  // D x 4D
+    const int N = x_lod[0].size() - 1;  // batch size
+    const int M = x_dims[1];            // x frame size
+    const int D = wh_dims[0];
+    const int D2 = D * 2;
+    const int D3 = D * 3;
     const int D4 = wh_dims[1];

     const T* x_data = x->data<T>();
+    const T* h0_data = h0 ? h0->data<T>() : NULL;
+    const T* c0_data = c0 ? c0->data<T>() : NULL;
     const T* wx_data = wx->data<T>();
+    const T* wh_data = wh->data<T>();
     T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
+    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
+    T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());

     auto blas = math::GetBlas<DeviceContext, T>(ctx);
     math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data,
                                       xx_data, bias->data<T>());
+
+    for (int i = 0; i < N; ++i) {
+      int seq_len = x_lod[0][i + 1] - x_lod[0][i];
+      const T* prev_cell_data = NULL;
+      const T* prev_hidden_data = NULL;
+      int tstart = 0;
+      if (h0_data) {
+        prev_hidden_data = h0_data + i * D;
+        prev_cell_data = c0_data + i * D;
+      } else {
+        // gate layout per step: W_ch, W_ih, W_fh, W_oh
+        // act gates: sigmoid over the input/forget/output blocks
+        math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
+        // candidate (ch) gate
+        math::vec_tanh<T>(D, xx_data, xx_data);
+        // cell out = input * tilde
+        blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
+        // hidden out = act_state(cell_out) * out gate
+        // act state
+        math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
+        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
+
+        // prev
+        prev_hidden_data = hidden_out_data;
+        prev_cell_data = cell_out_data;
+        tstart = 1;
+
+        // move offset
+        xx_data = xx_data + D4;
+        hidden_out_data = hidden_out_data + D;
+        cell_out_data = cell_out_data + D;
+      }
+      for (int step = tstart; step < seq_len; ++step) {
+        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
+                  prev_hidden_data, D, wh_data, D4, static_cast<T>(1), xx_data,
+                  D4);
+
+        // gate layout per step: W_ch, W_ih, W_fh, W_oh
+        // act gates
+        math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
+        // candidate (ch) gate
+        math::vec_tanh<T>(D, xx_data, xx_data);
+
+        // a = forget * prev_cell
+        blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2);
+
+        // b = input * tilde
+        blas.VMUL(D, xx_data, xx_data + D, xx_data + D);
+
+        // cell out = a + b
+        blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);
+
+        // hidden out = act_state(cell_out) * out gate
+        // act state
+        math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
+        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
+
+        // prev
+        prev_hidden_data = hidden_out_data;
+        prev_cell_data = cell_out_data;
+
+        // move offset
+        xx_data = xx_data + D4;
+        hidden_out_data = hidden_out_data + D;
+        cell_out_data = cell_out_data + D;
+      }
+    }
   }

   void BatchCompute(const framework::ExecutionContext& ctx) const {
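For reference, a sketch of the per-timestep recurrence that SeqCompute above implements, read off the diff: each 1 x 4D row of XX holds the four pre-activation blocks in the W_ch, W_ih, W_fh, W_oh order noted in the comments (candidate, input, forget, output, each of width D). The W_* and b_* symbols below are the conventional LSTM names, not identifiers from this operator.

  pre_t = x_t * WeightX + h_{t-1} * WeightH + Bias   (FCCompute up front, then the 1 x 4D GEMM per step)
  c~_t  = tanh(pre_t[0:D])                           (vec_tanh on the candidate block)
  i_t   = sigmoid(pre_t[D:2D])
  f_t   = sigmoid(pre_t[2D:3D])
  o_t   = sigmoid(pre_t[3D:4D])                      (one vec_sigmoid over the D3-wide gate block)
  c_t   = f_t . c_{t-1} + i_t . c~_t                 (elementwise; VMUL, VMUL, VADD)
  h_t   = o_t . tanh(c_t)                            (elementwise; vec_tanh, VMUL)

When H0/C0 are not provided, the first step has no previous state, so the f_t . c_{t-1} term is dropped and c_1 = i_1 . c~_1; that is the else branch before the step loop, after which tstart is set to 1.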