@@ -31,77 +31,34 @@ bool VarOutLinksToOp(Node* node, const std::string& op_type) {
31
31
}
32
32
33
33
void BuildFCPattern (PDPattern* pattern) {  // Fills `pattern` with the mul -> elementwise_add subgraph that the FC fuse handler is run on.
34
- // make sure the selected MUL op has one input argument is a parameter.
35
- auto * mul_parameter_var = pattern->NewNode (
36
- [](Node* node) {
37
- return node->IsVar () && node->outputs .size () == 1UL &&
38
- node->outputs .front ()->Op ()->Type () == " mul" && node->Var () &&
39
- node->Var ()->Persistable (); // check is a parameter
40
- },
41
- " mul_weight" /* name*/ );
42
-
43
- auto * mul_tmp_input_var = pattern->NewNode (
44
- [](Node* node) {
45
- bool result =
46
- node->IsVar () && node->outputs .size () >= 1UL && node->Var () &&
47
- !node->Var ()->Persistable (); // this input is not an parameter.
48
- if (!result) return false ;
49
- // check whether one output is MUL op.
50
- for (auto * op : node->outputs ) {
51
- if (op->IsOp () && op->Op ()->Type () == " mul" ) return true ;
52
- }
53
- return false ;
54
- },
55
- " mul_tmp_var" /* name*/ );
56
-
57
- // select a MUL op
58
- auto * mul_op = pattern->NewNode (
59
- [](Node* node) {
60
- return node->IsOp () && // start from an Op
61
- node->Op ()->Type () == " mul" ; // type is mul
62
- // the output should be consumed only by one element_add, that check
63
- // leaves in a Var PDNode.
64
- },
65
- " mul" /* name*/ );
66
-
67
- // make sure the MUL op's output has only one consumer and links to an
68
- // ELEMENTWISE_ADD op.
69
- auto * mul_out_var = pattern->NewNode (
70
- [](Node* node) {
71
- return node->IsVar () && // starts from a Var
72
- node->outputs .size () == 1UL && // only has one consumer
73
- node->outputs .front ()->IsOp () && // check basic logic
74
- node->Var () && // not a ControlDepVar
75
- node->outputs .front ()->Op ()->Type () ==
76
- " elementwise_add" ; // a very strong validation
77
- },
78
- " mul_out" );
79
- // this check is not essential, just to make the corresponding variable Node
80
- // retrival easier.
81
- auto * elementwise_add_tmp_var = pattern->NewNode (
82
- [](Node* node) {
83
- return node->IsVar () && node->outputs .size () >= 1UL && node->Var () &&
84
- VarOutLinksToOp (node, " elementwise_add" );
85
- },
86
- " elementwise_add_tmpvar" );
87
-
88
- // select an ELEMENTWISE_ADD op
89
- auto * elementwise_add_op = pattern->NewNode (
90
- [](Node* node) {
91
- return node->IsOp () && node->Op ()->Type () == " elementwise_add" ;
92
- },
93
- " elementwise_add" /* name*/ );
94
-
95
- // get the ELEMENTWISE_ADD op's output
96
- auto * elementwise_add_out_var = pattern->NewNode (
97
- [](Node* node) {
98
- return node->IsVar () && node->inputs .size () == 1UL && node->Var () &&
99
- node->inputs .front ()->Op ()->Type () == " elementwise_add" ;
100
- },
101
- " elementwise_add_out" );
102
-
103
- mul_op->LinksFrom ({mul_parameter_var, mul_tmp_input_var})
104
- .LinksTo ({mul_out_var});
34
+ // Create the two operator nodes of the pattern: a mul op and an elementwise_add op.
35
+ auto * mul_op = pattern->NewNode (" mul" )->assert_is_op (" mul" );
36
+ auto * elementwise_add_op =
37
+ pattern->NewNode (" elementwise_add" )->assert_is_op (" elementwise_add" );
38
+ // Create the variable nodes that the two operators read from and write to.
39
+ // w: weight operand of mul, asserted via the op's "Y" input. NOTE(review): the removed predicate also required Persistable(); no such check is visible here - confirm it is enforced elsewhere.
40
+ auto * mul_weight_var = pattern->NewNode (" mul_weight" )
41
+ ->AsInput ()
42
+ ->assert_is_op_nth_input (" mul" , " Y" , 0 );
43
+ // x: data operand of mul, asserted via the op's "X" input.
44
+ auto * mul_tmp_var = pattern->NewNode (" mul_tmp_var" )
45
+ ->AsInput ()
46
+ ->assert_is_op_nth_input (" mul" , " X" , 0 );
47
+ // Intermediate variable linking mul to elementwise_add; it is removed from the IR after the fuse. NOTE(review): the removed predicate required exactly one consumer (outputs.size() == 1) - confirm that is still guaranteed before the node is deleted.
48
+ auto * mul_out_var = pattern->NewNode (" mul_out" )
49
+ ->AsIntermediate ()
50
+ ->assert_is_only_output_of_op (" mul" )
51
+ ->assert_is_op_input (" elementwise_add" );
52
+ // bias: the other input of elementwise_add, linked alongside mul_out below.
53
+ auto * elementwise_add_tmp_var = pattern->NewNode (" elementwise_add_tmpvar" )
54
+ ->assert_is_op_input (" elementwise_add" )
55
+ ->AsInput ();
56
+ // output: the variable written by elementwise_add.
57
+ auto * elementwise_add_out_var = pattern->NewNode (" elementwise_add_out" )
58
+ ->AsOutput ()
59
+ ->assert_is_op_output (" elementwise_add" );
60
+
61
+ mul_op->LinksFrom ({mul_weight_var, mul_tmp_var}).LinksTo ({mul_out_var});
105
62
elementwise_add_op->LinksFrom ({mul_out_var, elementwise_add_tmp_var})
106
63
.LinksTo ({elementwise_add_out_var});
107
64
}
@@ -120,18 +77,20 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
120
77
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl (
121
78
std::unique_ptr<ir::Graph> graph) const {
122
79
PADDLE_ENFORCE (graph.get ());
80
+ FusePassBase::Init (" fc" , graph.get ());
123
81
124
82
std::unordered_set<Node*> nodes2delete;
125
83
126
84
GraphPatternDetector gpd;
127
85
BuildFCPattern (gpd.mutable_pattern ());
128
86
129
- #define GET_NODE (id ) \
130
- PADDLE_ENFORCE (subgraph.count (gpd.pattern ().RetriveNode (#id)), \
131
- " pattern has no Node called %s" , #id); \
132
- auto * id = subgraph.at (gpd.pattern ().RetriveNode (#id)); \
87
+ #define GET_NODE (id ) \
88
+ PADDLE_ENFORCE (subgraph.count (gpd.pattern ().RetrieveNode (#id)), \
89
+ " pattern has no Node called %s" , #id); \
90
+ auto * id = subgraph.at (gpd.pattern ().RetrieveNode (#id)); \
133
91
PADDLE_ENFORCE_NOT_NULL (id, " subgraph has no node %s" , #id);
134
92
93
+ int found_fc_count = 0 ;
135
94
auto handler = [&](const GraphPatternDetector::subgraph_t & subgraph,
136
95
Graph* g) {
137
96
VLOG (4 ) << " handle FC fuse" ;
@@ -176,10 +135,13 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
176
135
graph->RemoveNode (mul);
177
136
graph->RemoveNode (elementwise_add);
178
137
graph->RemoveNode (mul_out); // tmp variable
138
+
139
+ found_fc_count++;
179
140
};
180
141
181
142
gpd (graph.get (), handler);
182
143
144
+ AddStatis (found_fc_count);
183
145
return graph;
184
146
}
185
147
0 commit comments