Skip to content

Commit 2ef34c6

Browse files
authored
refine fc with pattern reusing (#13187)
1 parent 796c87d commit 2ef34c6

File tree

5 files changed

+49
-115
lines changed

5 files changed

+49
-115
lines changed

paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,13 @@ void FindWhileOp(Graph* graph) {
9999
auto* cell_init = graph->RetriveNode(6);
100100
auto* hidden_init = graph->RetriveNode(8);
101101

102-
#define LINK_TO(node0, node1) \
103-
node0->outputs.push_back(node1); \
104-
node1->inputs.push_back(node0);
105-
106102
auto* lstm_op = graph->CreateOpNode(&op_desc);
107103
PrepareParameters(graph, param);
108104

109-
LINK_TO(X, lstm_op);
110-
LINK_TO(cell_init, lstm_op);
111-
LINK_TO(hidden_init, lstm_op);
112-
LINK_TO(lstm_op, LSTMOUT);
105+
IR_NODE_LINK_TO(X, lstm_op);
106+
IR_NODE_LINK_TO(cell_init, lstm_op);
107+
IR_NODE_LINK_TO(hidden_init, lstm_op);
108+
IR_NODE_LINK_TO(lstm_op, LSTMOUT);
113109

114110
GraphSafeRemoveNodes(graph, marked_nodes);
115111
}

paddle/fluid/framework/ir/fc_fuse_pass.cc

Lines changed: 29 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -21,59 +21,6 @@ namespace paddle {
2121
namespace framework {
2222
namespace ir {
2323

24-
bool VarOutLinksToOp(Node* node, const std::string& op_type) {
25-
for (auto* out : node->outputs) {
26-
if (out->IsOp() && out->Op()->Type() == op_type) {
27-
return true;
28-
}
29-
}
30-
return false;
31-
}
32-
33-
void BuildFCPattern(PDPattern* pattern) {
34-
// Create Operators
35-
auto* mul_op = pattern->NewNode("mul")->assert_is_op("mul");
36-
auto* elementwise_add_op =
37-
pattern->NewNode("elementwise_add")->assert_is_op("elementwise_add");
38-
// Create variables
39-
// w
40-
auto* mul_weight_var = pattern->NewNode("mul_weight")
41-
->AsInput()
42-
->assert_is_op_nth_input("mul", "Y", 0);
43-
// x
44-
auto* mul_tmp_var = pattern->NewNode("mul_tmp_var")
45-
->AsInput()
46-
->assert_is_op_nth_input("mul", "X", 0);
47-
// intermediate variable, will be removed in the IR after fuse.
48-
auto* mul_out_var = pattern->NewNode("mul_out")
49-
->AsIntermediate()
50-
->assert_is_only_output_of_op("mul")
51-
->assert_is_op_input("elementwise_add");
52-
// bias
53-
auto* elementwise_add_tmp_var = pattern->NewNode("elementwise_add_tmpvar")
54-
->assert_is_op_input("elementwise_add")
55-
->AsInput();
56-
// output
57-
auto* elementwise_add_out_var = pattern->NewNode("elementwise_add_out")
58-
->AsOutput()
59-
->assert_is_op_output("elementwise_add");
60-
61-
mul_op->LinksFrom({mul_weight_var, mul_tmp_var}).LinksTo({mul_out_var});
62-
elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var})
63-
.LinksTo({elementwise_add_out_var});
64-
}
65-
66-
// Replace the node `from` in the links to `to`
67-
bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
68-
for (auto*& n : *links) {
69-
if (n == from) {
70-
n = to;
71-
return true;
72-
}
73-
}
74-
return false;
75-
}
76-
7724
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
7825
std::unique_ptr<ir::Graph> graph) const {
7926
PADDLE_ENFORCE(graph.get());
@@ -82,13 +29,18 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
8229
std::unordered_set<Node*> nodes2delete;
8330

8431
GraphPatternDetector gpd;
85-
BuildFCPattern(gpd.mutable_pattern());
86-
87-
#define GET_NODE(id) \
88-
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode(#id)), \
89-
"pattern has no Node called %s", #id); \
90-
auto* id = subgraph.at(gpd.pattern().RetrieveNode(#id)); \
91-
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
32+
// BuildFCPattern(gpd.mutable_pattern());
33+
auto* x = gpd.mutable_pattern()
34+
->NewNode("fc_fuse/x")
35+
->AsInput()
36+
->assert_is_op_input("mul", "X");
37+
patterns::FC(gpd.mutable_pattern(), "fc_fuse", x, true /*with bias*/);
38+
39+
#define GET_NODE(id) \
40+
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode("fc_fuse/" #id)), \
41+
"pattern has no Node called %s", #id); \
42+
auto* id = subgraph.at(gpd.pattern().RetrieveNode("fc_fuse/" #id)); \
43+
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", "fc_fuse/" #id);
9244

9345
int found_fc_count = 0;
9446
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
@@ -98,43 +50,33 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
9850
// scenario.
9951
// FC's fusion is simple, just op fuse, no need to process the
10052
// parameters.
101-
GET_NODE(mul_tmp_var); // x
102-
GET_NODE(mul_weight); // Y
103-
GET_NODE(elementwise_add_tmpvar); // bias
104-
GET_NODE(elementwise_add_out); // Out
105-
GET_NODE(mul); // MUL op
106-
GET_NODE(elementwise_add); // ELEMENT_ADD op
107-
GET_NODE(mul_out); // tmp
53+
GET_NODE(x); // x
54+
GET_NODE(w); // Y
55+
GET_NODE(fc_bias); // bias
56+
GET_NODE(fc_out); // Out
57+
GET_NODE(mul); // MUL op
58+
GET_NODE(elementwise_add); // ELEMENT_ADD op
59+
GET_NODE(mul_out); // tmp
10860
#undef GET_NODE
10961

11062
// Create an FC Node.
11163
OpDesc desc;
112-
std::string fc_x_in = mul_tmp_var->Name();
113-
std::string fc_Y_in = mul_weight->Name();
114-
std::string fc_bias_in = elementwise_add_tmpvar->Name();
115-
std::string fc_out = elementwise_add_out->Name();
64+
std::string fc_x_in = x->Name();
65+
std::string fc_Y_in = w->Name();
66+
std::string fc_bias_in = fc_bias->Name();
67+
std::string fc_out_out = fc_out->Name();
11668
desc.SetInput("Input", std::vector<std::string>({fc_x_in}));
11769
desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
11870
desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
119-
desc.SetOutput("Out", std::vector<std::string>({fc_out}));
71+
desc.SetOutput("Out", std::vector<std::string>({fc_out_out}));
12072
desc.SetType("fc");
12173
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
122-
fc_node->inputs =
123-
std::vector<Node*>({mul_tmp_var, mul_weight, elementwise_add_tmpvar});
124-
fc_node->outputs.push_back(elementwise_add_out);
125-
126-
// Update link relations
127-
PADDLE_ENFORCE(LinksReplace(&mul_tmp_var->outputs, mul, fc_node));
128-
PADDLE_ENFORCE(LinksReplace(&mul_weight->outputs, mul, fc_node));
129-
PADDLE_ENFORCE(LinksReplace(&elementwise_add_tmpvar->outputs,
130-
elementwise_add, fc_node));
131-
PADDLE_ENFORCE(
132-
LinksReplace(&elementwise_add_out->inputs, elementwise_add, fc_node));
74+
GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
13375

134-
// Drop old nodes
135-
graph->RemoveNode(mul);
136-
graph->RemoveNode(elementwise_add);
137-
graph->RemoveNode(mul_out); // tmp variable
76+
IR_NODE_LINK_TO(x, fc_node);
77+
IR_NODE_LINK_TO(w, fc_node);
78+
IR_NODE_LINK_TO(fc_bias, fc_node);
79+
IR_NODE_LINK_TO(fc_node, fc_out);
13880

13981
found_fc_count++;
14082
};

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
121121
#undef TMP_NEW
122122
#undef TMP_NAME
123123

124-
#define LINK_TO(a, b) \
125-
a->outputs.push_back(b); \
126-
b->inputs.push_back(a);
127-
LINK_TO(input_n, op);
128-
LINK_TO(weight_x_n, op);
129-
LINK_TO(weight_h_n, op);
130-
LINK_TO(bias_n, op);
131-
LINK_TO(op, hidden_n);
132-
#undef LINK_TO
124+
IR_NODE_LINK_TO(input_n, op);
125+
IR_NODE_LINK_TO(weight_x_n, op);
126+
IR_NODE_LINK_TO(weight_h_n, op);
127+
IR_NODE_LINK_TO(bias_n, op);
128+
IR_NODE_LINK_TO(op, hidden_n);
133129
return op;
134130
};
135131

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,10 @@ PDNode* LSTM(PDPattern* pattern, const std::string& name_scope, PDNode* x);
297297

298298
} // namespace patterns
299299

300+
#define IR_NODE_LINK_TO(a, b) \
301+
a->outputs.push_back(b); \
302+
b->inputs.push_back(a);
303+
300304
} // namespace ir
301305
} // namespace framework
302306
} // namespace paddle

paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -219,16 +219,13 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
219219
op_desc.SetAttr("fc_activation", act->Op()->Type());
220220

221221
auto* op_node = graph->CreateOpNode(&op_desc);
222-
// Add links
223-
#define NODE_LINKS(a, b) \
224-
a->outputs.push_back(b); \
225-
b->inputs.push_back(a);
226-
NODE_LINKS(fc_w, op_node);
227-
NODE_LINKS(fc_bias, op_node);
228-
NODE_LINKS(concat_in0, op_node);
229-
NODE_LINKS(sequence_expand0_in, op_node);
230-
NODE_LINKS(sequence_expand1_in, op_node);
231-
NODE_LINKS(op_node, fc_out);
222+
// Add links
223+
IR_NODE_LINK_TO(fc_w, op_node);
224+
IR_NODE_LINK_TO(fc_bias, op_node);
225+
IR_NODE_LINK_TO(concat_in0, op_node);
226+
IR_NODE_LINK_TO(sequence_expand0_in, op_node);
227+
IR_NODE_LINK_TO(sequence_expand1_in, op_node);
228+
IR_NODE_LINK_TO(op_node, fc_out);
232229

233230
// Clean nodes.
234231
std::unordered_set<const Node*> marked_nodes;
@@ -241,7 +238,6 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
241238
marked_nodes.erase(sequence_expand0_in);
242239
marked_nodes.erase(sequence_expand1_in);
243240
marked_nodes.erase(fc_out);
244-
245241
GraphSafeRemoveNodes(graph, marked_nodes);
246242
});
247243

0 commit comments

Comments
 (0)