Skip to content

Commit 478a4e8

Browse files
authored
refactor ir pattern (#13304)
1 parent 14242ea commit 478a4e8

File tree

7 files changed

+316
-240
lines changed

7 files changed

+316
-240
lines changed

paddle/fluid/framework/ir/fc_fuse_pass.cc

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,39 +29,27 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
2929
std::unordered_set<Node*> nodes2delete;
3030

3131
GraphPatternDetector gpd;
32-
// BuildFCPattern(gpd.mutable_pattern());
3332
auto* x = gpd.mutable_pattern()
3433
->NewNode("fc_fuse/x")
3534
->AsInput()
3635
->assert_is_op_input("mul", "X");
37-
patterns::FC(gpd.mutable_pattern(), "fc_fuse", x, true /*with bias*/);
38-
39-
#define GET_NODE(id) \
40-
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode("fc_fuse/" #id)), \
41-
"pattern has no Node called %s", #id); \
42-
auto* id = subgraph.at(gpd.pattern().RetrieveNode("fc_fuse/" #id)); \
43-
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", "fc_fuse/" #id);
36+
patterns::FC fc_pattern(gpd.mutable_pattern(), "fc_fuse");
37+
fc_pattern(x, true /*with bias*/);
4438

4539
int found_fc_count = 0;
4640
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
4741
Graph* g) {
4842
VLOG(4) << "handle FC fuse";
49-
// Currently, there is no FC op available, so I will just simulate the
50-
// scenerio.
51-
// FC's fusion is simple, just op fuse, no need to process the
52-
// parameters.
53-
GET_NODE(x); // x
54-
GET_NODE(w); // Y
55-
GET_NODE(fc_bias); // bias
56-
GET_NODE(fc_out); // Out
57-
GET_NODE(mul); // MUL op
58-
GET_NODE(elementwise_add); // ELEMENT_ADD op
59-
GET_NODE(mul_out); // tmp
60-
#undef GET_NODE
43+
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
44+
GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
45+
GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
46+
GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
47+
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
48+
GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
6149

6250
// Create an FC Node.
6351
OpDesc desc;
64-
std::string fc_x_in = x->Name();
52+
std::string fc_x_in = subgraph.at(x)->Name();
6553
std::string fc_Y_in = w->Name();
6654
std::string fc_bias_in = fc_bias->Name();
6755
std::string fc_out_out = fc_out->Name();
@@ -73,7 +61,8 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
7361
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
7462
GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
7563

76-
IR_NODE_LINK_TO(x, fc_node);
64+
PADDLE_ENFORCE(subgraph.count(x));
65+
IR_NODE_LINK_TO(subgraph.at(x), fc_node);
7766
IR_NODE_LINK_TO(w, fc_node);
7867
IR_NODE_LINK_TO(fc_bias, fc_node);
7968
IR_NODE_LINK_TO(fc_node, fc_out);

paddle/fluid/framework/ir/fc_gru_fuse_pass.cc

Lines changed: 44 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -20,52 +20,43 @@ namespace paddle {
2020
namespace framework {
2121
namespace ir {
2222

23-
static void BuildPattern(PDPattern* pattern, const std::string& name_scope,
24-
bool with_fc_bias) {
25-
PDNode* x = pattern->NewNode(name_scope, "x")
26-
->assert_is_op_input("mul")
27-
->assert_var_not_persistable();
28-
auto* fc_out = patterns::FC(pattern, name_scope, x, with_fc_bias);
29-
fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse.
30-
patterns::GRU(pattern, name_scope, fc_out);
31-
VLOG(3) << "fc_gru pattern \n" << pattern->DotString();
32-
}
33-
3423
static int BuildFusion(Graph* graph, const std::string& name_scope,
3524
Scope* scope, bool with_fc_bias) {
3625
GraphPatternDetector gpd;
3726
auto* pattern = gpd.mutable_pattern();
3827

39-
BuildPattern(pattern, name_scope, with_fc_bias);
28+
// Create pattern.
29+
patterns::FC fc_pattern(pattern, name_scope);
30+
patterns::GRU gru_pattern(pattern, name_scope);
31+
32+
PDNode* x =
33+
pattern->NewNode(patterns::UniqueKey("x"))->assert_var_not_persistable();
34+
35+
auto* fc_out = fc_pattern(x, with_fc_bias);
36+
fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse.
37+
gru_pattern(fc_out);
4038

4139
// Create New OpDesc
42-
auto gru_creater = [&](int gru, int x, int weight_x, int weight_h, int bias,
43-
int hidden, int fc_bias) {
44-
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
45-
GET_NODE(x);
46-
GET_NODE(weight_x);
47-
GET_NODE(weight_h);
48-
GET_NODE(bias);
49-
GET_NODE(hidden);
50-
GET_NODE(gru);
40+
auto gru_creater = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h,
41+
Node* bias, Node* hidden, Node* fc_bias) {
5142

5243
OpDesc op_desc;
5344
op_desc.SetType("fusion_gru");
5445

5546
#define NEW_NAME(x) name_scope + "/at." #x ".new"
56-
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()});
47+
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
5748
SET_IN(X, x);
5849
SET_IN(WeightX, weight_x);
5950
SET_IN(WeightH, weight_h);
6051
if (with_fc_bias) {
61-
op_desc.SetInput("Bias", {NEW_NAME(bias) + bias_n->Name()});
52+
op_desc.SetInput("Bias", {NEW_NAME(bias) + bias->Name()});
6253
} else {
6354
SET_IN(Bias, bias);
6455
}
6556
#undef SET_IN
6657
op_desc.SetInput("H0", {});
67-
op_desc.SetOutput("Hidden", {hidden_n->Name()});
68-
op_desc.SetAttr("is_reverse", gru_n->Op()->GetAttr("is_reverse"));
58+
op_desc.SetOutput("Hidden", {hidden->Name()});
59+
op_desc.SetAttr("is_reverse", gru->Op()->GetAttr("is_reverse"));
6960
// TODO(TJ): This should be a option for infer
7061
op_desc.SetAttr("use_seq", true);
7162

@@ -82,14 +73,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
8273
PADDLE_ENFORCE(scope);
8374
if (with_fc_bias) {
8475
// Fusion GRU bias = fcbias + grubias
85-
auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias_n->Name());
76+
auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias->Name());
8677
auto* out_bias_tensor =
8778
fusion_bias_var->GetMutable<framework::LoDTensor>();
8879
PADDLE_ENFORCE(fusion_bias_var);
89-
GET_NODE(fc_bias);
90-
PADDLE_ENFORCE(fc_bias_n);
91-
auto* gru_bias_var = scope->FindVar(bias_n->Name());
92-
auto* fc_bias_var = scope->FindVar(fc_bias_n->Name());
80+
auto* gru_bias_var = scope->FindVar(bias->Name());
81+
auto* fc_bias_var = scope->FindVar(fc_bias->Name());
9382
PADDLE_ENFORCE(gru_bias_var);
9483
PADDLE_ENFORCE(fc_bias_var);
9584
const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>();
@@ -113,54 +102,47 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
113102
#undef NEW_NAME
114103
#undef NEW_IMTERMEDIATE_OUT
115104

116-
IR_NODE_LINK_TO(x_n, op);
117-
IR_NODE_LINK_TO(weight_x_n, op);
118-
IR_NODE_LINK_TO(weight_h_n, op);
119-
IR_NODE_LINK_TO(bias_n, op); // actually should link to new bias if have
120-
IR_NODE_LINK_TO(op, hidden_n);
105+
IR_NODE_LINK_TO(x, op);
106+
IR_NODE_LINK_TO(weight_x, op);
107+
IR_NODE_LINK_TO(weight_h, op);
108+
IR_NODE_LINK_TO(bias, op); // actually should link to new bias if have
109+
IR_NODE_LINK_TO(op, hidden);
121110
// h0?
122111
return op;
123112
};
124113

125114
int fusion_count{0};
126115
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
127116
Graph* g) {
128-
#define GET_NODE(name__) \
129-
std::string name__##key = name_scope + "/" + #name__; \
130-
auto* name__##n = pattern->RetrieveNode(name__##key); \
131-
PADDLE_ENFORCE(name__##n); \
132-
PADDLE_ENFORCE(subgraph.count(name__##n)); \
133-
Node* name__##_n = subgraph.at(name__##n); \
134-
int name__ __attribute__((unused)) = name__##_n->id();
135-
136-
GET_NODE(x);
137-
GET_NODE(w); // fc weight
138-
GET_NODE(mul);
139-
GET_NODE(fc_out);
140-
GET_NODE(Weight);
141-
GET_NODE(gru);
142-
GET_NODE(Bias);
143-
GET_NODE(Hidden);
117+
auto* x_n = subgraph.at(x);
118+
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
119+
GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
120+
GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
121+
GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, gru_pattern);
122+
GET_IR_NODE_FROM_SUBGRAPH(gru, gru, gru_pattern);
123+
GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, gru_pattern);
124+
GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, gru_pattern);
144125
// nodes need be removed
145-
GET_NODE(BatchGate);
146-
GET_NODE(BatchResetHiddenPrev);
147-
GET_NODE(BatchHidden);
126+
GET_IR_NODE_FROM_SUBGRAPH(BatchGate, BatchGate, gru_pattern);
127+
GET_IR_NODE_FROM_SUBGRAPH(BatchResetHiddenPrev, BatchGate, gru_pattern);
128+
GET_IR_NODE_FROM_SUBGRAPH(BatchHidden, BatchGate, gru_pattern);
148129

149130
if (with_fc_bias) {
150-
GET_NODE(mul_out);
151-
GET_NODE(fc_bias);
152-
GET_NODE(elementwise_add);
153-
gru_creater(gru, x, w, Weight, Bias, Hidden, fc_bias);
131+
GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
132+
GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
133+
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
134+
135+
gru_creater(gru, x_n, w, Weight, Bias, Hidden, fc_bias);
154136
// Remove unneeded nodes.
155137
std::unordered_set<const Node*> marked_nodes(
156-
{mul_n, gru_n, elementwise_add_n, fc_bias_n, fc_out_n, mul_out_n,
157-
BatchGate_n, BatchResetHiddenPrev_n, BatchHidden_n});
138+
{mul, gru, elementwise_add, fc_bias, fc_out, mul_out, BatchGate,
139+
BatchResetHiddenPrev, BatchHidden});
158140
GraphSafeRemoveNodes(graph, marked_nodes);
159141
} else {
160-
gru_creater(gru, x, w, Weight, Bias, Hidden, -1);
142+
gru_creater(gru, x_n, w, Weight, Bias, Hidden, nullptr);
161143
// Remove unneeded nodes.
162144
std::unordered_set<const Node*> marked_nodes(
163-
{mul_n, gru_n, BatchGate_n, BatchResetHiddenPrev_n, BatchHidden_n});
145+
{mul, gru, BatchGate, BatchResetHiddenPrev, BatchHidden});
164146
GraphSafeRemoveNodes(graph, marked_nodes);
165147
}
166148
#undef GET_NODE

0 commit comments

Comments
 (0)