Skip to content

Commit 607c419

Browse files
committed
compute gates
1 parent 6be273c commit 607c419

File tree

1 file changed

+84
-3
lines changed

1 file changed

+84
-3
lines changed

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,24 +220,105 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
220220
void SeqCompute(const framework::ExecutionContext& ctx) const {
221221
using DeviceContext = paddle::platform::CPUDeviceContext;
222222
auto* x = ctx.Input<LoDTensor>("X");
223+
auto* h0 = ctx.Input<Tensor>("H0");
224+
auto* c0 = ctx.Input<Tensor>("C0");
223225
auto* wx = ctx.Input<Tensor>("WeightX");
224226
auto* wh = ctx.Input<Tensor>("WeightH");
225227
auto* bias = ctx.Input<Tensor>("Bias");
226228

227229
auto* xx = ctx.Output<LoDTensor>("XX");
230+
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
231+
auto* cell_out = ctx.Output<LoDTensor>("Cell");
228232

229-
auto x_dims = x->dims(); // T x M
230-
auto wh_dims = wh->dims(); // D x 4D
231-
const int M = x_dims[1]; // x frame size
233+
auto x_lod = x->lod();
234+
auto x_dims = x->dims(); // T x M
235+
auto wh_dims = wh->dims(); // D x 4D
236+
const int N = x_lod[0].size() - 1; // batch size
237+
const int M = x_dims[1]; // x frame size
238+
const int D = wh_dims[0];
239+
const int D2 = D * 2;
240+
const int D3 = D * 3;
232241
const int D4 = wh_dims[1];
233242

234243
const T* x_data = x->data<T>();
244+
const T* h0_data = h0 ? h0->data<T>() : NULL;
245+
const T* c0_data = c0 ? c0->data<T>() : NULL;
235246
const T* wx_data = wx->data<T>();
247+
const T* wh_data = wh->data<T>();
236248
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
249+
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
250+
T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
237251

238252
auto blas = math::GetBlas<DeviceContext, T>(ctx);
239253
math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data,
240254
xx_data, bias->data<T>());
255+
256+
for (int i = 0; i < N; ++i) {
257+
int seq_len = x_lod[0][i + 1] - x_lod[0][i];
258+
const T* prev_cell_data = NULL;
259+
const T* prev_hidden_data = NULL;
260+
int tstart = 0;
261+
if (h0_data) {
262+
prev_hidden_data = h0_data + i * D;
263+
prev_cell_data = c0_data + i * D;
264+
} else {
265+
// W_ch, W_ih, W_fh, W_oh
266+
// actgate
267+
math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
268+
// ch gate
269+
math::vec_tanh<T>(D, xx_data, xx_data);
270+
// cell out= input*tilde
271+
blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
272+
// hidden out= act_state(cellout) * outgate
273+
// act state
274+
math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
275+
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
276+
277+
// prev
278+
prev_hidden_data = hidden_out_data;
279+
prev_cell_data = cell_out_data;
280+
tstart = 1;
281+
282+
// move offset
283+
xx_data = xx_data + D4;
284+
hidden_out_data = hidden_out_data + D;
285+
cell_out_data = cell_out_data + D;
286+
}
287+
for (int step = tstart; step < seq_len; ++step) {
288+
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
289+
prev_hidden_data, D, wh_data, D4, static_cast<T>(1), xx_data,
290+
D4);
291+
292+
// W_ch, W_ih, W_fh, W_oh
293+
// actgate
294+
math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
295+
// ch gate
296+
math::vec_tanh<T>(D, xx_data, xx_data);
297+
298+
// a = forget * prev_cell
299+
blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2);
300+
301+
// b = input * tilde
302+
blas.VMUL(D, xx_data, xx_data + D, xx_data + D);
303+
304+
// cell out= a+b
305+
blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);
306+
307+
// hidden out= act_state(cellout) * outgate
308+
// act state
309+
math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
310+
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
311+
312+
// prev
313+
prev_hidden_data = hidden_out_data;
314+
prev_cell_data = cell_out_data;
315+
316+
// move offset
317+
xx_data = xx_data + D4;
318+
hidden_out_data = hidden_out_data + D;
319+
cell_out_data = cell_out_data + D;
320+
}
321+
}
241322
}
242323

243324
void BatchCompute(const framework::ExecutionContext& ctx) const {

0 commit comments

Comments
 (0)