@@ -21,59 +21,6 @@ namespace paddle {
21
21
namespace framework {
22
22
namespace ir {
23
23
24
- bool VarOutLinksToOp (Node* node, const std::string& op_type) {
25
- for (auto * out : node->outputs ) {
26
- if (out->IsOp () && out->Op ()->Type () == op_type) {
27
- return true ;
28
- }
29
- }
30
- return false ;
31
- }
32
-
33
- void BuildFCPattern (PDPattern* pattern) {
34
- // Create Operators
35
- auto * mul_op = pattern->NewNode (" mul" )->assert_is_op (" mul" );
36
- auto * elementwise_add_op =
37
- pattern->NewNode (" elementwise_add" )->assert_is_op (" elementwise_add" );
38
- // Create variables
39
- // w
40
- auto * mul_weight_var = pattern->NewNode (" mul_weight" )
41
- ->AsInput ()
42
- ->assert_is_op_nth_input (" mul" , " Y" , 0 );
43
- // x
44
- auto * mul_tmp_var = pattern->NewNode (" mul_tmp_var" )
45
- ->AsInput ()
46
- ->assert_is_op_nth_input (" mul" , " X" , 0 );
47
- // intermediate variable, will be removed in the IR after fuse.
48
- auto * mul_out_var = pattern->NewNode (" mul_out" )
49
- ->AsIntermediate ()
50
- ->assert_is_only_output_of_op (" mul" )
51
- ->assert_is_op_input (" elementwise_add" );
52
- // bias
53
- auto * elementwise_add_tmp_var = pattern->NewNode (" elementwise_add_tmpvar" )
54
- ->assert_is_op_input (" elementwise_add" )
55
- ->AsInput ();
56
- // output
57
- auto * elementwise_add_out_var = pattern->NewNode (" elementwise_add_out" )
58
- ->AsOutput ()
59
- ->assert_is_op_output (" elementwise_add" );
60
-
61
- mul_op->LinksFrom ({mul_weight_var, mul_tmp_var}).LinksTo ({mul_out_var});
62
- elementwise_add_op->LinksFrom ({mul_out_var, elementwise_add_tmp_var})
63
- .LinksTo ({elementwise_add_out_var});
64
- }
65
-
66
- // Replace the node `from` in the links to `to`
67
- bool LinksReplace (std::vector<Node*>* links, Node* from, Node* to) {
68
- for (auto *& n : *links) {
69
- if (n == from) {
70
- n = to;
71
- return true ;
72
- }
73
- }
74
- return false ;
75
- }
76
-
77
24
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl (
78
25
std::unique_ptr<ir::Graph> graph) const {
79
26
PADDLE_ENFORCE (graph.get ());
@@ -82,13 +29,18 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
82
29
std::unordered_set<Node*> nodes2delete;
83
30
84
31
GraphPatternDetector gpd;
85
- BuildFCPattern (gpd.mutable_pattern ());
86
-
87
- #define GET_NODE (id ) \
88
- PADDLE_ENFORCE (subgraph.count (gpd.pattern ().RetrieveNode (#id)), \
89
- " pattern has no Node called %s" , #id); \
90
- auto * id = subgraph.at (gpd.pattern ().RetrieveNode (#id)); \
91
- PADDLE_ENFORCE_NOT_NULL (id, " subgraph has no node %s" , #id);
32
+ // BuildFCPattern(gpd.mutable_pattern());
33
+ auto * x = gpd.mutable_pattern ()
34
+ ->NewNode (" fc_fuse/x" )
35
+ ->AsInput ()
36
+ ->assert_is_op_input (" mul" , " X" );
37
+ patterns::FC (gpd.mutable_pattern (), " fc_fuse" , x, true /* with bias*/ );
38
+
39
+ #define GET_NODE (id ) \
40
+ PADDLE_ENFORCE (subgraph.count (gpd.pattern ().RetrieveNode (" fc_fuse/" #id)), \
41
+ " pattern has no Node called %s" , #id); \
42
+ auto * id = subgraph.at (gpd.pattern ().RetrieveNode (" fc_fuse/" #id)); \
43
+ PADDLE_ENFORCE_NOT_NULL (id, " subgraph has no node %s" , " fc_fuse/" #id);
92
44
93
45
int found_fc_count = 0 ;
94
46
auto handler = [&](const GraphPatternDetector::subgraph_t & subgraph,
@@ -98,43 +50,33 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
98
50
// scenario.
99
51
// FC's fusion is simple, just op fuse, no need to process the
100
52
// parameters.
101
- GET_NODE (mul_tmp_var); // x
102
- GET_NODE (mul_weight); // Y
103
- GET_NODE (elementwise_add_tmpvar); // bias
104
- GET_NODE (elementwise_add_out); // Out
105
- GET_NODE (mul); // MUL op
106
- GET_NODE (elementwise_add); // ELEMENT_ADD op
107
- GET_NODE (mul_out); // tmp
53
+ GET_NODE (x); // x
54
+ GET_NODE (w); // Y
55
+ GET_NODE (fc_bias); // bias
56
+ GET_NODE (fc_out); // Out
57
+ GET_NODE (mul); // MUL op
58
+ GET_NODE (elementwise_add); // ELEMENT_ADD op
59
+ GET_NODE (mul_out); // tmp
108
60
#undef GET_NODE
109
61
110
62
// Create an FC Node.
111
63
OpDesc desc;
112
- std::string fc_x_in = mul_tmp_var ->Name ();
113
- std::string fc_Y_in = mul_weight ->Name ();
114
- std::string fc_bias_in = elementwise_add_tmpvar ->Name ();
115
- std::string fc_out = elementwise_add_out ->Name ();
64
+ std::string fc_x_in = x ->Name ();
65
+ std::string fc_Y_in = w ->Name ();
66
+ std::string fc_bias_in = fc_bias ->Name ();
67
+ std::string fc_out_out = fc_out ->Name ();
116
68
desc.SetInput (" Input" , std::vector<std::string>({fc_x_in}));
117
69
desc.SetInput (" W" , std::vector<std::string>({fc_Y_in}));
118
70
desc.SetInput (" Bias" , std::vector<std::string>({fc_bias_in}));
119
- desc.SetOutput (" Out" , std::vector<std::string>({fc_out }));
71
+ desc.SetOutput (" Out" , std::vector<std::string>({fc_out_out }));
120
72
desc.SetType (" fc" );
121
73
auto fc_node = g->CreateOpNode (&desc); // OpDesc will be copied.
122
- fc_node->inputs =
123
- std::vector<Node*>({mul_tmp_var, mul_weight, elementwise_add_tmpvar});
124
- fc_node->outputs .push_back (elementwise_add_out);
125
-
126
- // Update link relations
127
- PADDLE_ENFORCE (LinksReplace (&mul_tmp_var->outputs , mul, fc_node));
128
- PADDLE_ENFORCE (LinksReplace (&mul_weight->outputs , mul, fc_node));
129
- PADDLE_ENFORCE (LinksReplace (&elementwise_add_tmpvar->outputs ,
130
- elementwise_add, fc_node));
131
- PADDLE_ENFORCE (
132
- LinksReplace (&elementwise_add_out->inputs , elementwise_add, fc_node));
74
+ GraphSafeRemoveNodes (graph.get (), {mul, elementwise_add, mul_out});
133
75
134
- // Drop old nodes
135
- graph-> RemoveNode (mul );
136
- graph-> RemoveNode (elementwise_add );
137
- graph-> RemoveNode (mul_out); // tmp variable
76
+ IR_NODE_LINK_TO (x, fc_node);
77
+ IR_NODE_LINK_TO (w, fc_node );
78
+ IR_NODE_LINK_TO (fc_bias, fc_node );
79
+ IR_NODE_LINK_TO (fc_node, fc_out);
138
80
139
81
found_fc_count++;
140
82
};
0 commit comments