Commit a5556d4

refine attentionlstm infershape
1 parent e0436ad commit a5556d4

1 file changed

paddle/fluid/operators/attention_lstm_op.cc

Lines changed: 60 additions & 28 deletions
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/attention_lstm_op.h"
 #include <string>
+#include "paddle/fluid/framework/shape_runtime_infer.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/cpu_vec.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
@@ -23,29 +24,60 @@ namespace paddle {
 namespace operators {
 
 void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("X"),
-                 "Input(X) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("C0"),
-                 "Input(C0) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
-                 "Input(LSTMWeight) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
-                 "Input(LSTMBias) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
-                 "Input(AttentionWeight) of AttentionLSTM should not be null.");
-
-  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                 "Output(Hidden) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                 "Output(Cell) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
-                 "Output(AttentionedX) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
-                 "Output(AttentionFCOut) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
-                 "Output(LSTMX) of AttentionLSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
-                 "Output(LSTMOUT) of AttentionLSTM should not be null.");
+  auto* runtime_ctx = dynamic_cast<framework::RuntimeInferShapeContext*>(ctx);
+  if (runtime_ctx == nullptr) {
+    LOG(FATAL) << "Should have runtime infer context";
+  }
+  const auto& ins = runtime_ctx->OpBase().Inputs();
+  const auto& outs = runtime_ctx->OpBase().Outputs();
+  const auto& scope = runtime_ctx->InferScope();
+  const auto ins_end = ins.end();
+  const auto outs_end = outs.end();
+  auto fair_input = [&](const std::string& name) -> bool {
+    auto it = ins.find(name);
+    if (it == ins_end) {
+      return false;
+    }
+    const auto& in = it->second;
+    if (in.size() != 1 || in[0] == framework::kEmptyVarName) {
+      return false;
+    }
+    return scope.FindVar(in[0]) != nullptr;
+  };
+  auto fair_output = [&](const std::string& name) -> bool {
+    auto it = outs.find(name);
+    if (it == outs_end) {
+      return false;
+    }
+    const auto& out = it->second;
+    if (out.size() != 1 || out[0] == framework::kEmptyVarName) {
+      return false;
+    }
+    return scope.FindVar(out[0]) != nullptr;
+  };
+
+  PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("C0"),
+                 "Assert only one Input(C0) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("LSTMWeight"),
+                 "Assert only one Input(LSTMWeight) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("LSTMBias"),
+                 "Assert only one Input(LSTMBias) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_input("AttentionWeight"),
+                 "Assert only one Input(AttentionWeight) of AttentionLSTM.");
+
+  PADDLE_ENFORCE(fair_output("Hidden"),
+                 "Assert only one Output(Hidden) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("Cell"),
+                 "Assert only one Output(Cell) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("AttentionedX"),
+                 "Assert only one Output(AttentionedX) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("AttentionFCOut"),
+                 "Assert only one Output(AttentionFCOut) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("LSTMX"),
+                 "Assert only one Output(LSTMX) of AttentionLSTM.");
+  PADDLE_ENFORCE(fair_output("LSTMOUT"),
+                 "Assert only one Output(LSTMOUT) of AttentionLSTM.");
 
   auto x_dims = ctx->GetInputDim("X");
   const int M = x_dims[1];
@@ -65,7 +97,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   auto c_dims = ctx->GetInputDim("C0");
   PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
   PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
-  if (ctx->HasInput("H0")) {
+  if (fair_input("H0")) {
     auto h_dims = ctx->GetInputDim("H0");
     PADDLE_ENFORCE(h_dims == c_dims,
                    "The dimension of Input(H0) and Input(C0) "
@@ -79,7 +111,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
                     "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
   PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
                     "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
-  if (ctx->HasInput("AttentionBias")) {
+  if (fair_input("AttentionBias")) {
     auto atten_b_dims = ctx->GetInputDim("AttentionBias");
     PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
                       "Input(AttentionBias)'s rank must be 2.");
@@ -89,18 +121,18 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
                       "AttentionBias shapes must be 1 * 1.");
   }
 
-  if (ctx->HasInput("AttentionScalar")) {
+  if (fair_input("AttentionScalar")) {
     auto dims = ctx->GetInputDim("AttentionScalar");
     PADDLE_ENFORCE_EQ(dims.size(), 2,
                       "Input(AttentionScalar)'s rank must be 2.");
     PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
     PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
   }
 
-  if (ctx->HasInput("AttentionScalarBias")) {
+  if (fair_input("AttentionScalarBias")) {
     auto dims = ctx->GetInputDim("AttentionScalarBias");
     PADDLE_ENFORCE(
-        ctx->HasInput("AttentionScalar"),
+        fair_input("AttentionScalar"),
         "AttentionScalar should not be null when have AttentionScalarBias.");
     PADDLE_ENFORCE_EQ(dims.size(), 2,
                       "Input(AttentionScalarBias)'s rank must be 2.");

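The LOG(FATAL) guard at the top of the new InferShape exists because the base InferShapeContext interface carries no scope; only the runtime variant does. A toy sketch of that downcast-and-fail-fast pattern, using hypothetical stand-in types rather than the real Paddle classes:

    #include <cstdlib>
    #include <iostream>

    // Hypothetical stand-ins: only the runtime context is backed by a scope,
    // so it is modeled as a distinct subclass.
    struct InferShapeContext {
      virtual ~InferShapeContext() = default;
    };
    struct RuntimeInferShapeContext : InferShapeContext {
      const char* scope_name = "runtime scope";
    };

    void InferShape(InferShapeContext* ctx) {
      // Same shape as the commit's guard: downcast, and fail fast when shape
      // inference is requested without a runtime scope behind it.
      auto* runtime_ctx = dynamic_cast<RuntimeInferShapeContext*>(ctx);
      if (runtime_ctx == nullptr) {
        std::cerr << "Should have runtime infer context\n";
        std::abort();
      }
      std::cout << "inferring with " << runtime_ctx->scope_name << "\n";
    }

    int main() {
      RuntimeInferShapeContext runtime_ctx;
      InferShape(&runtime_ctx);  // ok: prints the scope name
      // InferShapeContext plain_ctx; InferShape(&plain_ctx);  // would abort
      return 0;
    }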