Skip to content

Commit a7f3450

Browse files
committed
cherry-pick #16898 and #16902.
test=Release/1.4
1 parent 392d737 commit a7f3450

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

paddle/fluid/operators/gru_op.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,11 @@ class GRUOp : public framework::OperatorWithKernel {
4747
auto weight_dims = ctx->GetInputDim("Weight");
4848
int input_size = input_dims[1];
4949
int frame_size = weight_dims[0];
50-
PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
51-
"The input_size must be 3 times of frame_size in GRUOp.");
50+
if (ctx->IsRuntime()) {
51+
PADDLE_ENFORCE_EQ(
52+
input_size, frame_size * 3,
53+
"The input_size must be 3 times of frame_size in GRUOp.");
54+
}
5255
PADDLE_ENFORCE_EQ(
5356
weight_dims[1], frame_size * 3,
5457
"The shape of Weight matrix must be [frame_size, frame_size * 3].");

paddle/fluid/operators/lstm_unit_op.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ class LstmUnitOp : public framework::OperatorWithKernel {
3434
auto c_prev_dims = ctx->GetInputDim("C_prev");
3535

3636
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
37-
PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0],
38-
"Batch size of inputs and states must be equal");
39-
PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4,
40-
"Dimension of FC should equal to prev state * 4");
37+
if (ctx->IsRuntime()) {
38+
PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0],
39+
"Batch size of inputs and states must be equal");
40+
PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4,
41+
"Dimension of FC should equal to prev state * 4");
42+
}
4143

4244
int b_size = c_prev_dims[0]; // batch size
4345
int s_dim = c_prev_dims[1]; // state dim

0 commit comments

Comments
 (0)