13
13
// limitations under the License.
14
14
15
15
#include " paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
16
+ #include " paddle/fluid/framework/lod_tensor.h"
16
17
17
18
namespace paddle {
18
19
namespace framework {
19
20
namespace ir {
20
21
21
- std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl (
22
- std::unique_ptr<ir::Graph> graph) const {
23
- GraphPatternDetector gpd;
24
- auto * pattern = gpd.mutable_pattern ();
25
-
26
- std::unordered_set<int > fused_ops ({// first lstm
27
- 13 , 15 , 16 ,
28
- // second lstm
29
- 23 , 25 , 26 });
30
-
31
- pattern->NewNode ([&](Node* x) { return fused_ops.count (x->id ()); },
32
- " any_node" );
22
+ std::string GenNodeName (const std::string& prefix, const std::string& name) {
23
+ return prefix + " /" + name;
24
+ }
33
25
34
- std::unordered_set<Node*> marked_nodes;
26
+ void BuildPattern (PDPattern* pattern, const std::string& name_scope,
27
+ bool with_fc_bias) {
28
+ PDNode* x = pattern->NewNode (name_scope, " x" )
29
+ ->assert_is_op_input (" mul" )
30
+ ->assert_var_not_persistable ();
31
+ auto * fc_out = patterns::FC (pattern, name_scope, x, with_fc_bias);
32
+ fc_out->AsIntermediate (); // fc_out is a tmp var, will be removed after fuse.
33
+ patterns::LSTM (pattern, name_scope, fc_out);
34
+ // LOG(INFO) << "\n" << pattern->DotString();
35
+ }
35
36
36
- auto handler = [&](const GraphPatternDetector::subgraph_t & subgraph,
37
- Graph* g) {
37
+ int BuildFusion (Graph* graph, const std::string& name_scope, Scope* scope,
38
+ bool with_fc_bias) {
39
+ GraphPatternDetector gpd;
40
+ auto * pattern = gpd.mutable_pattern ();
38
41
39
- auto * id = subgraph.at (gpd.pattern ().RetrieveNode (" any_node" ));
40
- marked_nodes.insert (id);
41
- };
42
- gpd (graph.get (), handler);
42
+ BuildPattern (pattern, name_scope, with_fc_bias);
43
43
44
44
// Create New OpDesc
45
45
auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h,
46
- int bias, int hidden, int cell, int xx) {
46
+ int bias, int hidden, int cell, int xx, int fc_bias ) {
47
47
#define GET_NODE (x ) auto * x##_n = graph->RetriveNode (x);
48
48
GET_NODE (input);
49
49
GET_NODE (weight_x);
@@ -61,12 +61,33 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
61
61
SET_IN (WeightX, weight_x);
62
62
SET_IN (WeightH, weight_h);
63
63
SET_IN (Bias, bias);
64
- #undef GET_NODE
65
64
#undef SET_IN
65
+ if (with_fc_bias) {
66
+ // Add FC-bias with LSTM-bias and create a new weight
67
+ PADDLE_ENFORCE (scope);
68
+ const std::string& new_bias_var = name_scope + " _bias.new" ;
69
+ auto * bias_var = scope->Var (new_bias_var);
70
+ PADDLE_ENFORCE (bias_var);
71
+ auto * bias_tensor = bias_var->GetMutable <framework::LoDTensor>();
72
+ auto * lstm_bias_var = scope->FindVar (bias_n->Name ());
73
+ PADDLE_ENFORCE (lstm_bias_var);
74
+ const auto & lstm_bias_tensor = lstm_bias_var->Get <framework::LoDTensor>();
75
+ bias_tensor->Resize (lstm_bias_tensor.dims ());
76
+
77
+ GET_NODE (fc_bias);
78
+ auto * fc_bias_var = scope->FindVar (fc_bias_n->Name ());
79
+ const auto & fc_bias_tensor = fc_bias_var->Get <framework::LoDTensor>();
80
+
81
+ auto * data = bias_tensor->mutable_data <float >(platform::CPUPlace ());
82
+
83
+ for (int i = 0 ; i < bias_tensor->numel (); i++) {
84
+ data[i] =
85
+ fc_bias_tensor.data <float >()[i] + lstm_bias_tensor.data <float >()[i];
86
+ }
87
+ op_desc.SetInput (" Bias" , {new_bias_var});
88
+ }
66
89
67
- VLOG (4 ) << " hidden_n: " << hidden_n->Name ();
68
- VLOG (4 ) << " cell: " << cell_n->Name ();
69
- VLOG (4 ) << " xx: " << xx_n->Name ();
90
+ #undef GET_NODE
70
91
71
92
op_desc.SetInput (" H0" , {});
72
93
op_desc.SetInput (" C0" , {});
@@ -76,7 +97,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
76
97
op_desc.SetOutput (" BatchedGate" , {" blstm_0.tmp_2" });
77
98
op_desc.SetOutput (" BatchCellPreAct" , {" blstm_1.tmp_2" });
78
99
op_desc.SetAttr (" is_reverse" , lstm_n->Op ()->GetAttr (" is_reverse" ));
79
- op_desc.SetAttr (" use_peepholes" , false );
100
+ op_desc.SetAttr (" use_peepholes" , lstm_n-> Op ()-> GetAttr ( " use_peepholes " ) );
80
101
auto * op = graph->CreateOpNode (&op_desc);
81
102
82
103
#define LINK_TO (a, b ) \
@@ -89,38 +110,77 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
89
110
LINK_TO (op, hidden_n);
90
111
#undef LINK_TO
91
112
return op;
92
-
93
113
};
94
114
95
- lstm_creator (16 , 12 , 14 , 18 , 17 , 22 , 21 , 19 );
96
- lstm_creator (26 , 12 , 24 , 28 , 27 , 32 , 31 , 29 );
115
+ int fusion_count{0 };
97
116
98
- // remove all the nodes
117
+ auto fc_no_bias_handler = [&](
118
+ const GraphPatternDetector::subgraph_t & subgraph, Graph* g) {
99
119
100
- for (auto * node : marked_nodes) {
101
- graph->RemoveNode (const_cast <Node*>(node));
102
- }
120
+ #define GET_NODE (name__ ) \
121
+ std::string name__##key = name_scope + " /" + #name__; \
122
+ auto * name__##n = pattern->RetrieveNode (name__##key); \
123
+ PADDLE_ENFORCE (name__##n); \
124
+ PADDLE_ENFORCE (subgraph.count (name__##n)); \
125
+ Node* name__##_n = subgraph.at (name__##n); \
126
+ int name__ __attribute__ ((unused)) = name__##_n->id ();
103
127
104
- for (auto * node : graph->Nodes ()) {
105
- for (auto it = node->inputs .begin (); it != node->inputs .end ();) {
106
- if (marked_nodes.count (*it)) {
107
- it = const_cast <Node*>(node)->inputs .erase (it);
108
- } else
109
- it++;
110
- }
111
- for (auto it = node->outputs .begin (); it != node->outputs .end ();) {
112
- if (marked_nodes.count (*it)) {
113
- it = const_cast <Node*>(node)->outputs .erase (it);
114
- } else
115
- it++;
128
+ GET_NODE (x);
129
+ GET_NODE (w);
130
+ GET_NODE (mul);
131
+ GET_NODE (fc_out);
132
+ GET_NODE (Weight);
133
+ GET_NODE (lstm);
134
+ GET_NODE (Bias);
135
+ GET_NODE (Hidden);
136
+ GET_NODE (Cell);
137
+
138
+ if (with_fc_bias) {
139
+ GET_NODE (fc_bias);
140
+ lstm_creator (lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, fc_bias);
141
+ } else {
142
+ lstm_creator (lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, -1 );
116
143
}
117
- }
144
+ #undef GET_NODE
145
+
146
+ // Remove unneeded nodes.
147
+ std::unordered_set<const Node*> marked_nodes ({mul_n, lstm_n});
148
+
149
+ GraphSafeRemoveNodes (graph, marked_nodes);
150
+
151
+ ++fusion_count;
152
+ };
153
+
154
+ gpd (graph, fc_no_bias_handler);
155
+
156
+ return fusion_count;
157
+ }
158
+
159
+ std::unique_ptr<ir::Graph> MulLstmFusePass::ApplyImpl (
160
+ std::unique_ptr<ir::Graph> graph) const {
161
+ FusePassBase::Init (name_scope_, graph.get ());
162
+
163
+ int fusion_count = BuildFusion (graph.get (), name_scope_, param_scope (),
164
+ false /* with_fc_bias*/ );
165
+
166
+ AddStatis (fusion_count);
167
+ return graph;
168
+ }
169
+
170
+ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl (
171
+ std::unique_ptr<ir::Graph> graph) const {
172
+ FusePassBase::Init (name_scope_, graph.get ());
173
+
174
+ int fusion_count = BuildFusion (graph.get (), name_scope_, param_scope (),
175
+ true /* with_fc_bias*/ );
118
176
177
+ AddStatis (fusion_count);
119
178
return graph;
120
179
}
121
180
122
181
} // namespace ir
123
182
} // namespace framework
124
183
} // namespace paddle
125
184
185
+ REGISTER_PASS (mul_lstm_fuse_pass, paddle::framework::ir::MulLstmFusePass);
126
186
REGISTER_PASS (fc_lstm_fuse_pass, paddle::framework::ir::FCLstmFusePass);
0 commit comments