Commit 916f42b

refine fusion gru infershape
1 parent a5556d4 commit 916f42b

File tree: 1 file changed (+49 / -16 lines)

paddle/fluid/operators/fusion_gru_op.cc

Lines changed: 49 additions & 16 deletions
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/fusion_gru_op.h"
 #include <cstring>  // for memcpy
 #include <string>
+#include "paddle/fluid/framework/shape_runtime_infer.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/cpu_vec.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
@@ -25,14 +26,46 @@ namespace paddle {
 namespace operators {
 
 void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("WeightX"),
-                 "Input(WeightX) of GRU should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("WeightH"),
-                 "Input(WeightH) of GRU should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                 "Output(Hidden) of GRU should not be null.");
+  auto* runtime_ctx = dynamic_cast<framework::RuntimeInferShapeContext*>(ctx);
+  if (runtime_ctx == nullptr) {
+    LOG(FATAL) << "Should have runtime infer context";
+  }
+  const auto& ins = runtime_ctx->OpBase().Inputs();
+  const auto& outs = runtime_ctx->OpBase().Outputs();
+  const auto& scope = runtime_ctx->InferScope();
+  const auto ins_end = ins.end();
+  const auto outs_end = outs.end();
+  auto fair_input = [&](const std::string& name) -> bool {
+    auto it = ins.find(name);
+    if (it == ins_end) {
+      return false;
+    }
+    const auto& in = it->second;
+    if (in.size() != 1 || in[0] == framework::kEmptyVarName) {
+      return false;
+    }
+    return scope.FindVar(in[0]) != nullptr;
+  };
+  auto fair_output = [&](const std::string& name) -> bool {
+    auto it = outs.find(name);
+    if (it == outs_end) {
+      return false;
+    }
+    const auto& out = it->second;
+    if (out.size() != 1 || out[0] == framework::kEmptyVarName) {
+      return false;
+    }
+    return scope.FindVar(out[0]) != nullptr;
+  };
+
+  PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of GRU.");
+  PADDLE_ENFORCE(fair_input("WeightX"),
+                 "Assert only one Input(WeightX) of GRU.");
+  PADDLE_ENFORCE(fair_input("WeightH"),
+                 "Assert only one Input(WeightH) of GRU.");
+  PADDLE_ENFORCE(fair_output("XX"), "Assert only one Output(XX) of GRU.");
+  PADDLE_ENFORCE(fair_output("Hidden"),
+                 "Assert only one Output(Hidden) of GRU.");
 
   auto x_dims = ctx->GetInputDim("X");
   PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
@@ -58,12 +91,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
                     "should be 3 * %d.",
                     frame_size);
 
-  if (ctx->HasInput("H0")) {
+  if (fair_input("H0")) {
     auto h0_dims = ctx->GetInputDim("H0");
     PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
                       "The width of H0 must be equal to frame_size.");
   }
-  if (ctx->HasInput("Bias")) {
+  if (fair_input("Bias")) {
     auto b_dims = ctx->GetInputDim("Bias");
     PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
     PADDLE_ENFORCE_EQ(b_dims[0], 1,
@@ -79,12 +112,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
     xx_width = wx_dims[1];
   } else {
     xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
-    PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
-                   "Output(ReorderedH0) of GRU should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
-                   "Output(BatchedInput) of GRU should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
-                   "Output(BatchedOut) of GRU should not be null.");
+    PADDLE_ENFORCE(fair_output("ReorderedH0"),
+                   "Assert only one Output(ReorderedH0) of GRU.");
+    PADDLE_ENFORCE(fair_output("BatchedInput"),
+                   "Assert only one Output(BatchedInput) of GRU.");
+    PADDLE_ENFORCE(fair_output("BatchedOut"),
+                   "Assert only one Output(BatchedOut) of GRU.");
     ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
     ctx->SetOutputDim("BatchedOut", out_dims);
   }
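For readers skimming the diff: the new fair_input/fair_output lambdas accept an argument slot only when it is declared on the operator, holds exactly one non-empty variable name, and that variable actually exists in the scope. Below is a minimal standalone sketch of that check; the Fair helper, the "@EMPTY@" placeholder, and the plain std::map/std::set containers are illustrative stand-ins, not the Paddle framework types.

// Standalone sketch (not Paddle code) of the "exactly one valid variable" check.
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

// Placeholder for an unset argument slot; stand-in for framework::kEmptyVarName.
const char kEmptyVarName[] = "@EMPTY@";

// A slot is "fair" if it is declared, maps to exactly one non-empty name,
// and that name is present in the (simplified) scope.
bool Fair(const std::map<std::string, std::vector<std::string>>& args,
          const std::set<std::string>& scope, const std::string& name) {
  auto it = args.find(name);
  if (it == args.end()) return false;  // slot not declared at all
  const auto& vars = it->second;
  if (vars.size() != 1 || vars[0] == kEmptyVarName) return false;  // not exactly one real var
  return scope.count(vars[0]) != 0;  // the variable must live in the scope
}

int main() {
  std::map<std::string, std::vector<std::string>> inputs = {
      {"X", {"x_tensor"}},
      {"H0", {kEmptyVarName}},       // optional input left empty
      {"WeightX", {"wx0", "wx1"}}};  // malformed: more than one variable
  std::set<std::string> scope = {"x_tensor"};

  std::cout << std::boolalpha;
  std::cout << Fair(inputs, scope, "X") << "\n";        // true
  std::cout << Fair(inputs, scope, "H0") << "\n";       // false: empty slot
  std::cout << Fair(inputs, scope, "WeightX") << "\n";  // false: not exactly one var
  std::cout << Fair(inputs, scope, "Bias") << "\n";     // false: not declared
  return 0;
}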
