@@ -76,12 +76,18 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
76
76
PADDLE_ENFORCE_EQ (b_dims.size (), 2 , " The rank of Input(Bias) should be 2." );
77
77
PADDLE_ENFORCE_EQ (b_dims[0 ], 1 ,
78
78
" The first dimension of Input(Bias) should be 1." );
79
- PADDLE_ENFORCE_EQ (
80
- b_dims[1 ], (ctx->Attrs ().Get <bool >(" use_peepholes" ) ? 7 : 4 ) * frame_size,
81
- " The second dimension of Input(Bias) should be "
82
- " 7 * %d if enable peepholes connection or"
83
- " 4 * %d if disable peepholes" ,
84
- frame_size, frame_size);
79
+ if (ctx->Attrs ().Get <bool >(" use_peepholes" )) {
80
+ PADDLE_ENFORCE_EQ (b_dims[1 ], 7 * frame_size,
81
+ " The second dimension of Input(Bias) should be "
82
+ " 7 * %d if enable peepholes connection" ,
83
+ frame_size);
84
+ ctx->SetOutputDim (" CheckedCell" , {2 , frame_size});
85
+ } else {
86
+ PADDLE_ENFORCE_EQ (b_dims[1 ], 4 * frame_size,
87
+ " The second dimension of Input(Bias) should be "
88
+ " 4 * %d if disable peepholes" ,
89
+ frame_size);
90
+ }
85
91
86
92
framework::DDim out_dims ({x_dims[0 ], frame_size});
87
93
ctx->SetOutputDim (" Hidden" , out_dims);
@@ -173,6 +179,8 @@ void FusionLSTMOpMaker::Make() {
173
179
AddOutput (" BatchedCell" , " (LoDTensor) (T x D)." ).AsIntermediate ();
174
180
AddOutput (" ReorderedH0" , " (LoDTensor) (N x D)." ).AsIntermediate ();
175
181
AddOutput (" ReorderedC0" , " (LoDTensor) (N x D)." ).AsIntermediate ();
182
+ AddOutput (" CheckedCell" , " (Tensor) (2 x D) only for peephole." )
183
+ .AsIntermediate ();
176
184
AddAttr<bool >(" use_peepholes" ,
177
185
" (bool, defalut: True) "
178
186
" whether to enable diagonal/peephole connections." )
@@ -250,19 +258,19 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
250
258
const int D3 = D * 3 ; \
251
259
const int D4 = wh_dims[1 ];
252
260
253
- #define INIT_BASE_INPUT_DATAS \
254
- const T* x_data = x->data<T>(); \
255
- const T* wx_data = wx->data<T>(); \
256
- const T* wh_data = wh->data<T>(); \
257
- /* diagonal weight*/ \
258
- const T* wc_data = bias->data<T>() + D4; \
259
- /* for peephole only*/ \
260
- Tensor checked_cell; \
261
- T* checked_cell_data = nullptr ; \
262
- auto place = ctx.GetPlace(); \
263
- if (use_peepholes) { \
264
- /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih */ \
265
- checked_cell_data = checked_cell. mutable_data <T>({ 2 , D}, place); \
261
+ #define INIT_BASE_INPUT_DATAS \
262
+ const T* x_data = x->data<T>(); \
263
+ const T* wx_data = wx->data<T>(); \
264
+ const T* wh_data = wh->data<T>(); \
265
+ /* diagonal weight*/ \
266
+ const T* wc_data = bias->data<T>() + D4; \
267
+ /* for peephole only*/ \
268
+ T* checked_cell_data = nullptr ; \
269
+ auto place = ctx.GetPlace(); \
270
+ if (use_peepholes) { \
271
+ /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih */ \
272
+ auto * checked_cell = ctx. Output <Tensor>( " CheckedCell " ); \
273
+ checked_cell_data = checked_cell-> mutable_data <T>(place); \
266
274
}
267
275
268
276
// / Compute LSTM
0 commit comments