Skip to content

Commit 9b30c51

Browse files
committed
Merge pull request #16861 from tensor-tang/refine/infershape
Separate runtime-only InferShape checks (guard dimension-value assertions with `ctx->IsRuntime()`)
1 parent 0cc984b commit 9b30c51

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

paddle/fluid/operators/attention_lstm_op.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,19 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
6464

6565
auto c_dims = ctx->GetInputDim("C0");
6666
PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
67-
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
67+
if (ctx->IsRuntime()) {
68+
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
69+
}
70+
6871
if (ctx->HasInput("H0")) {
6972
auto h_dims = ctx->GetInputDim("H0");
70-
PADDLE_ENFORCE(h_dims == c_dims,
71-
"The dimension of Input(H0) and Input(C0) "
72-
"should be the same.");
73+
PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, "Input(H0)'s rank must be 2.");
74+
if (ctx->IsRuntime() ||
75+
(framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) {
76+
PADDLE_ENFORCE(h_dims == c_dims,
77+
"The dimension of Input(H0) and Input(C0) "
78+
"should be the same.");
79+
}
7380
}
7481

7582
auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
@@ -79,6 +86,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
7986
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
8087
PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
8188
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
89+
8290
if (ctx->HasInput("AttentionBias")) {
8391
auto atten_b_dims = ctx->GetInputDim("AttentionBias");
8492
PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,

0 commit comments

Comments (0)