Skip to content

Commit 62a9821

Browse files
authored
Merge pull request #13346 from tensor-tang/refine/infershape
Refine/infershape
2 parents 6abe03b + b0b5f51 commit 62a9821

File tree

4 files changed

+48
-51
lines changed

4 files changed

+48
-51
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -464,35 +464,35 @@ class RuntimeInferShapeContext : public InferShapeContext {
464464
: op_(op), scope_(scope) {}
465465

466466
bool HasInput(const std::string& name) const override {
467-
if (!op_.HasInputs(name)) {
467+
// has only one input
468+
const auto& ins = op_.Inputs();
469+
auto it = ins.find(name);
470+
if (it == ins.end()) {
468471
return false;
469472
}
470-
auto& ins = Inputs(name);
471-
size_t length = ins.size();
472-
if (length == 0) {
473+
const auto& in = it->second;
474+
if (in.size() == 0 || in[0] == kEmptyVarName) {
473475
return false;
474476
}
475-
PADDLE_ENFORCE_EQ(length, 1UL,
477+
PADDLE_ENFORCE_EQ(in.size(), 1UL,
476478
"Input %s should not have more than one inputs", name);
477-
auto ipt = ins[0];
478-
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
479-
return var != nullptr;
479+
return scope_.FindVar(in[0]) != nullptr;
480480
}
481481

482482
bool HasOutput(const std::string& name) const override {
483-
if (!op_.HasOutputs(name)) {
483+
// has only one output
484+
const auto& outs = op_.Outputs();
485+
auto it = outs.find(name);
486+
if (it == outs.end()) {
484487
return false;
485488
}
486-
auto& outs = Outputs(name);
487-
size_t length = outs.size();
488-
if (length == 0) {
489+
const auto& out = it->second;
490+
if (out.size() == 0 || out[0] == kEmptyVarName) {
489491
return false;
490492
}
491-
PADDLE_ENFORCE_EQ(length, 1UL,
492-
"Output %s should not have more than one inputs", name);
493-
auto ipt = outs[0];
494-
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
495-
return var != nullptr;
493+
PADDLE_ENFORCE_EQ(out.size(), 1UL,
494+
"Output %s should not have more than one outputs", name);
495+
return scope_.FindVar(out[0]) != nullptr;
496496
}
497497

498498
bool HasInputs(const std::string& name) const override {

paddle/fluid/operators/attention_lstm_op.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,28 +24,28 @@ namespace operators {
2424

2525
void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
2626
PADDLE_ENFORCE(ctx->HasInput("X"),
27-
"Input(X) of AttentionLSTM should not be null.");
27+
"Assert only one Input(X) of AttentionLSTM.");
2828
PADDLE_ENFORCE(ctx->HasInput("C0"),
29-
"Input(C0) of AttentionLSTM should not be null.");
29+
"Assert only one Input(C0) of AttentionLSTM.");
3030
PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
31-
"Input(LSTMWeight) of AttentionLSTM should not be null.");
31+
"Assert only one Input(LSTMWeight) of AttentionLSTM.");
3232
PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
33-
"Input(LSTMBias) of AttentionLSTM should not be null.");
33+
"Assert only one Input(LSTMBias) of AttentionLSTM.");
3434
PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
35-
"Input(AttentionWeight) of AttentionLSTM should not be null.");
35+
"Assert only one Input(AttentionWeight) of AttentionLSTM.");
3636

3737
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
38-
"Output(Hidden) of AttentionLSTM should not be null.");
38+
"Assert only one Output(Hidden) of AttentionLSTM.");
3939
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
40-
"Output(Cell) of AttentionLSTM should not be null.");
40+
"Assert only one Output(Cell) of AttentionLSTM.");
4141
PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
42-
"Output(AttentionedX) of AttentionLSTM should not be null.");
42+
"Assert only one Output(AttentionedX) of AttentionLSTM.");
4343
PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
44-
"Output(AttentionFCOut) of AttentionLSTM should not be null.");
44+
"Assert only one Output(AttentionFCOut) of AttentionLSTM.");
4545
PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
46-
"Output(LSTMX) of AttentionLSTM should not be null.");
46+
"Assert only one Output(LSTMX) of AttentionLSTM.");
4747
PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
48-
"Output(LSTMOUT) of AttentionLSTM should not be null.");
48+
"Assert only one Output(LSTMOUT) of AttentionLSTM.");
4949

5050
auto x_dims = ctx->GetInputDim("X");
5151
const int M = x_dims[1];

paddle/fluid/operators/fusion_gru_op.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ namespace paddle {
2525
namespace operators {
2626

2727
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
28-
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null.");
28+
PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of GRU.");
2929
PADDLE_ENFORCE(ctx->HasInput("WeightX"),
30-
"Input(WeightX) of GRU should not be null.");
30+
"Assert only one Input(WeightX) of GRU.");
3131
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
32-
"Input(WeightH) of GRU should not be null.");
33-
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
32+
"Assert only one Input(WeightH) of GRU.");
33+
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of GRU.");
3434
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
35-
"Output(Hidden) of GRU should not be null.");
35+
"Assert only one Output(Hidden) of GRU.");
3636

3737
auto x_dims = ctx->GetInputDim("X");
3838
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
@@ -80,11 +80,11 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
8080
} else {
8181
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
8282
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
83-
"Output(ReorderedH0) of GRU should not be null.");
83+
"Assert only one Output(ReorderedH0) of GRU.");
8484
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
85-
"Output(BatchedInput) of GRU should not be null.");
85+
"Assert only one Output(BatchedInput) of GRU.");
8686
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
87-
"Output(BatchedOut) of GRU should not be null.");
87+
"Assert only one Output(BatchedOut) of GRU.");
8888
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
8989
ctx->SetOutputDim("BatchedOut", out_dims);
9090
}

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,17 @@ namespace paddle {
2424
namespace operators {
2525

2626
void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
27-
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
27+
PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of LSTM.");
2828
PADDLE_ENFORCE(ctx->HasInput("WeightX"),
29-
"Input(WeightX) of LSTM should not be null.");
29+
"Assert only one Input(WeightX) of LSTM.");
3030
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
31-
"Input(WeightH) of LSTM should not be null.");
32-
PADDLE_ENFORCE(ctx->HasInput("Bias"),
33-
"Input(Bias) of LSTM should not be null.");
34-
35-
PADDLE_ENFORCE(ctx->HasOutput("XX"),
36-
"Output(XX) of LSTM should not be null.");
31+
"Assert only one Input(WeightH) of LSTM.");
32+
PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM.");
33+
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM.");
3734
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
38-
"Output(Hidden) of LSTM should not be null.");
35+
"Assert only one Output(Hidden) of LSTM.");
3936
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
40-
"Output(Cell) of LSTM should not be null.");
37+
"Assert only one Output(Cell) of LSTM.");
4138

4239
auto x_dims = ctx->GetInputDim("X");
4340
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
@@ -96,15 +93,15 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
9693
} else {
9794
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
9895
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
99-
"Output(BatchedInput) of LSTM should not be null.");
96+
"Assert only one Output(BatchedInput) of LSTM.");
10097
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
101-
"Output(BatchedHidden) of LSTM should not be null.");
98+
"Assert only one Output(BatchedHidden) of LSTM.");
10299
PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
103-
"Output(BatchedCell) of LSTM should not be null.");
100+
"Assert only one Output(BatchedCell) of LSTM.");
104101
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
105-
"Output(ReorderedH0) of LSTM should not be null.");
102+
"Assert only one Output(ReorderedH0) of LSTM");
106103
PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
107-
"Output(ReorderedC0) of LSTM should not be null.");
104+
"Assert only one Output(ReorderedC0) of LSTM.");
108105
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
109106
ctx->SetOutputDim("BatchedHidden", out_dims);
110107
ctx->SetOutputDim("BatchedCell", out_dims);

0 commit comments

Comments
 (0)