@@ -15,6 +15,7 @@ limitations under the License. */
15
15
#include " paddle/fluid/operators/fusion_gru_op.h"
16
16
#include < cstring> // for memcpy
17
17
#include < string>
18
+ #include " paddle/fluid/framework/shape_runtime_infer.h"
18
19
#include " paddle/fluid/operators/math/blas.h"
19
20
#include " paddle/fluid/operators/math/cpu_vec.h"
20
21
#include " paddle/fluid/operators/math/fc_compute.h"
@@ -25,14 +26,46 @@ namespace paddle {
25
26
namespace operators {
26
27
27
28
void FusionGRUOp::InferShape (framework::InferShapeContext* ctx) const {
28
- PADDLE_ENFORCE (ctx->HasInput (" X" ), " Input(X) of GRU should not be null." );
29
- PADDLE_ENFORCE (ctx->HasInput (" WeightX" ),
30
- " Input(WeightX) of GRU should not be null." );
31
- PADDLE_ENFORCE (ctx->HasInput (" WeightH" ),
32
- " Input(WeightH) of GRU should not be null." );
33
- PADDLE_ENFORCE (ctx->HasOutput (" XX" ), " Output(XX) of GRU should not be null." );
34
- PADDLE_ENFORCE (ctx->HasOutput (" Hidden" ),
35
- " Output(Hidden) of GRU should not be null." );
29
+ auto * runtime_ctx = dynamic_cast <framework::RuntimeInferShapeContext*>(ctx);
30
+ if (runtime_ctx == nullptr ) {
31
+ LOG (FATAL) << " Should have runtime infer context" ;
32
+ }
33
+ const auto & ins = runtime_ctx->OpBase ().Inputs ();
34
+ const auto & outs = runtime_ctx->OpBase ().Outputs ();
35
+ const auto & scope = runtime_ctx->InferScope ();
36
+ const auto ins_end = ins.end ();
37
+ const auto outs_end = outs.end ();
38
+ auto fair_input = [&](const std::string& name) -> bool {
39
+ auto it = ins.find (name);
40
+ if (it == ins_end) {
41
+ return false ;
42
+ }
43
+ const auto & in = it->second ;
44
+ if (in.size () != 1 || in[0 ] == framework::kEmptyVarName ) {
45
+ return false ;
46
+ }
47
+ return scope.FindVar (in[0 ]) != nullptr ;
48
+ };
49
+ auto fair_output = [&](const std::string& name) -> bool {
50
+ auto it = outs.find (name);
51
+ if (it == outs_end) {
52
+ return false ;
53
+ }
54
+ const auto & out = it->second ;
55
+ if (out.size () != 1 || out[0 ] == framework::kEmptyVarName ) {
56
+ return false ;
57
+ }
58
+ return scope.FindVar (out[0 ]) != nullptr ;
59
+ };
60
+
61
+ PADDLE_ENFORCE (fair_input (" X" ), " Assert only one Input(X) of GRU." );
62
+ PADDLE_ENFORCE (fair_input (" WeightX" ),
63
+ " Assert only one Input(WeightX) of GRU." );
64
+ PADDLE_ENFORCE (fair_input (" WeightH" ),
65
+ " Assert only one Input(WeightH) of GRU." );
66
+ PADDLE_ENFORCE (fair_output (" XX" ), " Assert only one Output(XX) of GRU." );
67
+ PADDLE_ENFORCE (fair_output (" Hidden" ),
68
+ " Assert only one Output(Hidden) of GRU." );
36
69
37
70
auto x_dims = ctx->GetInputDim (" X" );
38
71
PADDLE_ENFORCE_EQ (x_dims.size (), 2 , " Input(X)'s rank must be 2." );
@@ -58,12 +91,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
58
91
" should be 3 * %d." ,
59
92
frame_size);
60
93
61
- if (ctx-> HasInput (" H0" )) {
94
+ if (fair_input (" H0" )) {
62
95
auto h0_dims = ctx->GetInputDim (" H0" );
63
96
PADDLE_ENFORCE_EQ (h0_dims[1 ], frame_size,
64
97
" The width of H0 must be equal to frame_size." );
65
98
}
66
- if (ctx-> HasInput (" Bias" )) {
99
+ if (fair_input (" Bias" )) {
67
100
auto b_dims = ctx->GetInputDim (" Bias" );
68
101
PADDLE_ENFORCE_EQ (b_dims.size (), 2 , " The rank of Input(Bias) should be 2." );
69
102
PADDLE_ENFORCE_EQ (b_dims[0 ], 1 ,
@@ -79,12 +112,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
79
112
xx_width = wx_dims[1 ];
80
113
} else {
81
114
xx_width = x_dims[1 ] > wx_dims[1 ] ? wx_dims[1 ] : x_dims[1 ];
82
- PADDLE_ENFORCE (ctx-> HasOutput (" ReorderedH0" ),
83
- " Output(ReorderedH0) of GRU should not be null ." );
84
- PADDLE_ENFORCE (ctx-> HasOutput (" BatchedInput" ),
85
- " Output(BatchedInput) of GRU should not be null ." );
86
- PADDLE_ENFORCE (ctx-> HasOutput (" BatchedOut" ),
87
- " Output(BatchedOut) of GRU should not be null ." );
115
+ PADDLE_ENFORCE (fair_output (" ReorderedH0" ),
116
+ " Assert only one Output(ReorderedH0) of GRU." );
117
+ PADDLE_ENFORCE (fair_output (" BatchedInput" ),
118
+ " Assert only one Output(BatchedInput) of GRU." );
119
+ PADDLE_ENFORCE (fair_output (" BatchedOut" ),
120
+ " Assert only one Output(BatchedOut) of GRU." );
88
121
ctx->SetOutputDim (" BatchedInput" , {x_dims[0 ], wx_dims[1 ]});
89
122
ctx->SetOutputDim (" BatchedOut" , out_dims);
90
123
}
0 commit comments