@@ -30,14 +30,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
30
30
" Input(WeightX) of GRU should not be null." );
31
31
PADDLE_ENFORCE (ctx->HasInput (" WeightH" ),
32
32
" Input(WeightH) of GRU should not be null." );
33
-
34
33
PADDLE_ENFORCE (ctx->HasOutput (" XX" ), " Output(XX) of GRU should not be null." );
35
- PADDLE_ENFORCE (ctx->HasOutput (" ReorderedH0" ),
36
- " Output(ReorderedH0) of GRU should not be null." );
37
- PADDLE_ENFORCE (ctx->HasOutput (" BatchedInput" ),
38
- " Output(BatchedInput) of GRU should not be null." );
39
- PADDLE_ENFORCE (ctx->HasOutput (" BatchedOut" ),
40
- " Output(BatchedOut) of GRU should not be null." );
41
34
PADDLE_ENFORCE (ctx->HasOutput (" Hidden" ),
42
35
" Output(Hidden) of GRU should not be null." );
43
36
@@ -80,15 +73,20 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
80
73
}
81
74
framework::DDim out_dims ({x_dims[0 ], frame_size});
82
75
ctx->SetOutputDim (" Hidden" , out_dims);
83
- ctx->SetOutputDim (" BatchedInput" , {x_dims[0 ], wx_dims[1 ]});
84
- ctx->SetOutputDim (" BatchedOut" , out_dims);
85
76
ctx->ShareLoD (" X" , " Hidden" );
86
-
87
77
int xx_width;
88
78
if (ctx->Attrs ().Get <bool >(" use_seq" )) {
89
79
xx_width = wx_dims[1 ];
90
80
} else {
91
81
xx_width = x_dims[1 ] > wx_dims[1 ] ? wx_dims[1 ] : x_dims[1 ];
82
+ PADDLE_ENFORCE (ctx->HasOutput (" ReorderedH0" ),
83
+ " Output(ReorderedH0) of GRU should not be null." );
84
+ PADDLE_ENFORCE (ctx->HasOutput (" BatchedInput" ),
85
+ " Output(BatchedInput) of GRU should not be null." );
86
+ PADDLE_ENFORCE (ctx->HasOutput (" BatchedOut" ),
87
+ " Output(BatchedOut) of GRU should not be null." );
88
+ ctx->SetOutputDim (" BatchedInput" , {x_dims[0 ], wx_dims[1 ]});
89
+ ctx->SetOutputDim (" BatchedOut" , out_dims);
92
90
}
93
91
ctx->SetOutputDim (" XX" , {x_dims[0 ], xx_width});
94
92
ctx->ShareLoD (" X" , " XX" );
0 commit comments