Skip to content

Commit 392d737

Browse files
committed
Merge pull request #16898 from Superjomn/fix/logical_op_infershape
fix logical op infershape
1 parent 21ea542 commit 392d737

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

paddle/fluid/operators/controlflow/logical_op.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,16 @@ class BinaryLogicalOpInferShape : public framework::InferShapeBase {
7171
"Input(Y) of %s operator must not be null", comment.type);
7272
auto dim_x = context->GetInputDim("X");
7373
auto dim_y = context->GetInputDim("Y");
74-
PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y),
75-
"The number of elements in X and Y should be same");
74+
75+
int product_x = framework::product(dim_x);
76+
int product_y = framework::product(dim_y);
77+
bool check = context->IsRuntime() || (product_x >= 0 && product_y >= 0);
78+
if (check) {
79+
PADDLE_ENFORCE_EQ(
80+
product_x, product_y,
81+
"The number of elements in X and Y should be same, %d != %d",
82+
product_x, product_y);
83+
}
7684

7785
context->SetOutputDim("Out", context->GetInputDim("X"));
7886
context->ShareLoD("X", "Out");

0 commit comments

Comments
 (0)