Skip to content

Commit 9de6d88

Browse files
authored
Merge pull request #13722 from tensor-tang/fix/release/1.0.0
cherry-pick 'bugfix: fusion lstm and gru batch,seq mode switch'
2 parents e85d636 + f613797 commit 9de6d88

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

paddle/fluid/operators/fusion_gru_op.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,12 +290,13 @@ class FusionGRUKernel : public framework::OpKernel<T> {
290290
void BatchCompute(const framework::ExecutionContext& ctx) const {
291291
using DeviceContext = paddle::platform::CPUDeviceContext;
292292
auto* x = ctx.Input<LoDTensor>("X");
293+
INIT_BASE_INPUT_OUTPUT
294+
INIT_BASE_SIZES
293295
if (x->lod()[0].size() == 2) {
296+
xx->Resize({total_T, D3});
294297
SeqCompute(ctx);
295298
return;
296299
}
297-
INIT_BASE_INPUT_OUTPUT
298-
INIT_BASE_SIZES
299300
INIT_VEC_FUNC
300301

301302
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,11 +424,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
424424
void BatchCompute(const framework::ExecutionContext& ctx) const {
425425
using DeviceContext = platform::CPUDeviceContext;
426426
INIT_BASE_INPUT_OUTPUT
427+
INIT_BASE_SIZES
427428
if (x->lod()[0].size() == 2) {
429+
xx->Resize({x_dims[0], D4});
428430
SeqCompute(ctx);
429431
return;
430432
}
431-
INIT_BASE_SIZES
432433
INIT_VEC_FUNC
433434
INIT_BASE_INPUT_DATAS
434435

0 commit comments

Comments
 (0)