@@ -64,12 +64,19 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
 
   auto c_dims = ctx->GetInputDim("C0");
   PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
+  if (ctx->IsRuntime()) {
+    PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
+  }
+
   if (ctx->HasInput("H0")) {
     auto h_dims = ctx->GetInputDim("H0");
-    PADDLE_ENFORCE(h_dims == c_dims,
-                   "The dimension of Input(H0) and Input(C0) "
-                   "should be the same.");
+    PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, "Input(H0)'s rank must be 2.");
+    if (ctx->IsRuntime() ||
+        (framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) {
+      PADDLE_ENFORCE(h_dims == c_dims,
+                     "The dimension of Input(H0) and Input(C0) "
+                     "should be the same.");
+    }
   }
 
   auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
@@ -79,6 +86,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
                     "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
   PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
                     "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+
   if (ctx->HasInput("AttentionBias")) {
     auto atten_b_dims = ctx->GetInputDim("AttentionBias");
     PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
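
A minimal standalone sketch of the pattern this change adopts (simplified stand-in types, not the real framework API): while the graph is being compiled, dimension values such as c_dims[1] may still be unset (for example -1 for a dynamic batch size), so value-level checks run only when ctx->IsRuntime() is true, or when framework::product() confirms every dimension is already positive; rank checks stay unconditional because rank is known even at compile time.

// Hypothetical reduction of the InferShape checks above; DDim, product(),
// and CheckC0H0 are illustrative stand-ins for the framework types.
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

using DDim = std::vector<int64_t>;

// Mirrors framework::product(): a non-positive result signals that at
// least one dimension is still unknown (e.g. -1) at compile time.
int64_t product(const DDim& d) {
  return std::accumulate(d.begin(), d.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

void CheckC0H0(bool is_runtime, const DDim& c_dims, const DDim& h_dims,
               int64_t D) {
  assert(c_dims.size() == 2u);  // rank check: always safe
  if (is_runtime) {
    assert(c_dims[1] == D);     // value check: deferred to runtime
  }
  assert(h_dims.size() == 2u);
  // Compare full shapes only once both are fully materialized.
  if (is_runtime || (product(c_dims) > 0 && product(h_dims) > 0)) {
    assert(h_dims == c_dims);
  }
}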