Skip to content

Commit 5ef14dd

Browse files
authored
Merge pull request #13715 from tensor-tang/fix/op
bugfix fusion lstm and gru batch,seq mode switch
2 parents c0dfd5e + ea0b98e commit 5ef14dd

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

paddle/fluid/operators/fusion_gru_op.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,12 +290,13 @@ class FusionGRUKernel : public framework::OpKernel<T> {
290290
void BatchCompute(const framework::ExecutionContext& ctx) const {
291291
using DeviceContext = paddle::platform::CPUDeviceContext;
292292
auto* x = ctx.Input<LoDTensor>("X");
293+
INIT_BASE_INPUT_OUTPUT
294+
INIT_BASE_SIZES
293295
if (x->lod()[0].size() == 2) {
296+
xx->Resize({total_T, D3});
294297
SeqCompute(ctx);
295298
return;
296299
}
297-
INIT_BASE_INPUT_OUTPUT
298-
INIT_BASE_SIZES
299300
INIT_VEC_FUNC
300301

301302
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,11 +432,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
432432
void BatchCompute(const framework::ExecutionContext& ctx) const {
433433
using DeviceContext = platform::CPUDeviceContext;
434434
INIT_BASE_INPUT_OUTPUT
435+
INIT_BASE_SIZES
435436
if (x->lod()[0].size() == 2) {
437+
xx->Resize({x_dims[0], D4});
436438
SeqCompute(ctx);
437439
return;
438440
}
439-
INIT_BASE_SIZES
440441
INIT_VEC_FUNC
441442
INIT_BASE_INPUT_DATAS
442443

0 commit comments

Comments
 (0)