Skip to content

Commit 0cc984b

Browse files
luotao1phlrain
authored andcommitted
Merge pull request #16852 from sneaxiy/fix_merge_lod_tensor_op_infer_shape
Fix merge_lod_tensor_op infer shape
1 parent 87be315 commit 0cc984b

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

paddle/fluid/operators/merge_lod_tensor_op.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,9 @@ class MergeLoDTensorInferShape : public framework::InferShapeBase {
164164

165165
auto mask_dim = context->GetInputDim("Mask");
166166
PADDLE_ENFORCE_EQ(mask_dim.size(), 2);
167-
PADDLE_ENFORCE_EQ(mask_dim[1], 1);
167+
if (context->IsRuntime() || mask_dim[1] > 0) {
168+
PADDLE_ENFORCE_EQ(mask_dim[1], 1);
169+
}
168170

169171
context->SetOutputDim("Out", context->GetInputDim("InTrue"));
170172
}

0 commit comments

Comments
 (0)