Skip to content

Commit 5a2d6d6

Browse files
authored
Merge pull request #16956 from guoshengCS/cherry-pick-infer-shape
cherry-pick #16898 and #16902 to release/1.4
2 parents aa056e9 + a7f3450 commit 5a2d6d6

File tree

3 files changed

+21
-8
lines changed

3 files changed

+21
-8
lines changed

paddle/fluid/operators/controlflow/logical_op.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,16 @@ class BinaryLogicalOpInferShape : public framework::InferShapeBase {
7171
"Input(Y) of %s operator must not be null", comment.type);
7272
auto dim_x = context->GetInputDim("X");
7373
auto dim_y = context->GetInputDim("Y");
74-
PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y),
75-
"The number of elements in X and Y should be same");
74+
75+
int product_x = framework::product(dim_x);
76+
int product_y = framework::product(dim_y);
77+
bool check = context->IsRuntime() || (product_x >= 0 && product_y >= 0);
78+
if (check) {
79+
PADDLE_ENFORCE_EQ(
80+
product_x, product_y,
81+
"The number of elements in X and Y should be same, %d != %d",
82+
product_x, product_y);
83+
}
7684

7785
context->SetOutputDim("Out", context->GetInputDim("X"));
7886
context->ShareLoD("X", "Out");

paddle/fluid/operators/gru_op.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,11 @@ class GRUOp : public framework::OperatorWithKernel {
4747
auto weight_dims = ctx->GetInputDim("Weight");
4848
int input_size = input_dims[1];
4949
int frame_size = weight_dims[0];
50-
PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
51-
"The input_size must be 3 times of frame_size in GRUOp.");
50+
if (ctx->IsRuntime()) {
51+
PADDLE_ENFORCE_EQ(
52+
input_size, frame_size * 3,
53+
"The input_size must be 3 times of frame_size in GRUOp.");
54+
}
5255
PADDLE_ENFORCE_EQ(
5356
weight_dims[1], frame_size * 3,
5457
"The shape of Weight matrix must be [frame_size, frame_size * 3].");

paddle/fluid/operators/lstm_unit_op.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ class LstmUnitOp : public framework::OperatorWithKernel {
3434
auto c_prev_dims = ctx->GetInputDim("C_prev");
3535

3636
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
37-
PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0],
38-
"Batch size of inputs and states must be equal");
39-
PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4,
40-
"Dimension of FC should equal to prev state * 4");
37+
if (ctx->IsRuntime()) {
38+
PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0],
39+
"Batch size of inputs and states must be equal");
40+
PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4,
41+
"Dimension of FC should equal to prev state * 4");
42+
}
4143

4244
int b_size = c_prev_dims[0]; // batch size
4345
int s_dim = c_prev_dims[1]; // state dim

0 commit comments

Comments
 (0)