@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/attention_lstm_op.h"
 #include <string>
+#include "paddle/fluid/framework/shape_runtime_infer.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/cpu_vec.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
@@ -23,29 +24,60 @@ namespace paddle {
 namespace operators {
 
 void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("X"),
-                 "Input(X) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("C0"),
-                 "Input(C0) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
-                 "Input(LSTMWeight) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
-                 "Input(LSTMBias) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
-                 "Input(AttentionWeight) of AttentionLSTM should not be null.");
-
-  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                 "Output(Hidden) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                 "Output(Cell) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
-                 "Output(AttentionedX) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
-                 "Output(AttentionFCOut) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
-                 "Output(LSTMX) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
-                 "Output(LSTMOUT) of AttentionLSTM should not be null.");
+  auto* runtime_ctx = dynamic_cast<framework::RuntimeInferShapeContext*>(ctx);
+  if (runtime_ctx == nullptr) {
+    LOG(FATAL) << "Should have runtime infer context";
+  }
+  const auto& ins = runtime_ctx->OpBase().Inputs();
+  const auto& outs = runtime_ctx->OpBase().Outputs();
+  const auto& scope = runtime_ctx->InferScope();
+  const auto ins_end = ins.end();
+  const auto outs_end = outs.end();
+  auto fair_input = [&](const std::string& name) -> bool {
+    auto it = ins.find(name);
+    if (it == ins_end) {
+      return false;
+    }
+    const auto& in = it->second;
+    if (in.size() != 1 || in[0] == framework::kEmptyVarName) {
+      return false;
+    }
+    return scope.FindVar(in[0]) != nullptr;
+  };
+  auto fair_output = [&](const std::string& name) -> bool {
+    auto it = outs.find(name);
+    if (it == outs_end) {
+      return false;
+    }
+    const auto& out = it->second;
+    if (out.size() != 1 || out[0] == framework::kEmptyVarName) {
+      return false;
+    }
+    return scope.FindVar(out[0]) != nullptr;
+  };
+
+  PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("C0"),
+                 "Assert only one Input(C0) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("LSTMWeight"),
+                 "Assert only one Input(LSTMWeight) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("LSTMBias"),
+                 "Assert only one Input(LSTMBias) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("AttentionWeight"),
+                 "Assert only one Input(AttentionWeight) of AttentionLSTM.");
+
+  PADDLE_ENFORCE(fair_output("Hidden"),
+                 "Assert only one Output(Hidden) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("Cell"),
+                 "Assert only one Output(Cell) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("AttentionedX"),
+                 "Assert only one Output(AttentionedX) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("AttentionFCOut"),
+                 "Assert only one Output(AttentionFCOut) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("LSTMX"),
+                 "Assert only one Output(LSTMX) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("LSTMOUT"),
+                 "Assert only one Output(LSTMOUT) of AttentionLSTM.");
 
   auto x_dims = ctx->GetInputDim("X");
   const int M = x_dims[1];
@@ -65,7 +97,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   auto c_dims = ctx->GetInputDim("C0");
   PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
   PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
-  if (ctx->HasInput("H0")) {
+  if (fair_input("H0")) {
     auto h_dims = ctx->GetInputDim("H0");
     PADDLE_ENFORCE(h_dims == c_dims,
                    "The dimension of Input(H0) and Input(C0) "
@@ -79,7 +111,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
                     "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
   PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
                     "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
-  if (ctx->HasInput("AttentionBias")) {
+  if (fair_input("AttentionBias")) {
     auto atten_b_dims = ctx->GetInputDim("AttentionBias");
     PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
                       "Input(AttentionBias)'s rank must be 2.");
@@ -89,18 +121,18 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
                       "AttentionBias shapes must be 1 * 1.");
   }
 
-  if (ctx->HasInput("AttentionScalar")) {
+  if (fair_input("AttentionScalar")) {
     auto dims = ctx->GetInputDim("AttentionScalar");
     PADDLE_ENFORCE_EQ(dims.size(), 2,
                       "Input(AttentionScalar)'s rank must be 2.");
     PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
     PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
   }
 
-  if (ctx->HasInput("AttentionScalarBias")) {
+  if (fair_input("AttentionScalarBias")) {
     auto dims = ctx->GetInputDim("AttentionScalarBias");
     PADDLE_ENFORCE(
-        ctx->HasInput("AttentionScalar"),
+        fair_input("AttentionScalar"),
        "AttentionScalar should not be null when have AttentionScalarBias.");
     PADDLE_ENFORCE_EQ(dims.size(), 2,
                       "Input(AttentionScalarBias)'s rank must be 2.");