@@ -99,7 +99,7 @@ bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
99
99
return false ;
100
100
}
101
101
102
- std::pair< bool , Node*> HasBias (const Node& op, const std::string& bias_name) {
102
+ boost::optional< Node*> HasBias (const Node& op, const std::string& bias_name) {
103
103
auto bias_input_names = op.Op ()->Inputs ();
104
104
auto bias_it = bias_input_names.find (bias_name);
105
105
@@ -113,19 +113,20 @@ std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name) {
113
113
[&bias_names](Node* n) -> bool {
114
114
return n->Name () == bias_names[0 ];
115
115
});
116
- return std::make_pair (has_bias, *bias_names_it) ;
116
+ return *bias_names_it;
117
117
}
118
118
}
119
119
120
- return std::make_pair ( false , nullptr ) ;
120
+ return boost::none ;
121
121
}
122
122
123
123
// Constructs a fuse handler from the three callbacks it is parameterised
// with: one that extracts the conv nodes from a matched subgraph, one that
// extracts the elementwise_add nodes, and a predicate deciding whether two
// ops may be fused.
//
// The fusion counter starts at zero and lives behind a shared_ptr.
// NOTE(review): presumably so that copies of the handler keep updating the
// same counter -- confirm against how GraphPatternDetector stores handlers.
ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler(
    const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv_op,
    const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
        get_node_from_elementwise_add_op,
    const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func)
    : fusion_stats(std::make_shared<int>(0)),
      get_node_from_conv_op(get_node_from_conv_op),
      get_node_from_elementwise_add_op(get_node_from_elementwise_add_op),
      can_fuse_func(can_fuse_func) {}
131
132
@@ -157,13 +158,10 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
157
158
op_desc.SetInput (" ResidualData" , {elementwise_add_identity->Name ()});
158
159
op_desc.SetOutput (" Output" , {conv_output->Name ()});
159
160
160
- bool has_bias;
161
- Node* conv_bias;
161
+ auto conv_bias = HasBias (*conv_op, " Bias" );
162
162
163
- std::tie (has_bias, conv_bias) = HasBias (*conv_op, " Bias" );
164
-
165
- if (has_bias) {
166
- op_desc.SetInput (" Bias" , {conv_bias->Name ()});
163
+ if (conv_bias) {
164
+ op_desc.SetInput (" Bias" , {(*conv_bias)->Name ()});
167
165
}
168
166
169
167
for (const auto & attr : conv_op->Op ()->GetAttrMap ()) {
@@ -179,40 +177,48 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
179
177
IR_NODE_LINK_TO (elementwise_add_identity, fused_conv_op);
180
178
IR_NODE_LINK_TO (fused_conv_op, conv_output);
181
179
182
- if (has_bias ) {
183
- IR_NODE_LINK_TO (conv_bias, fused_conv_op);
180
+ if (conv_bias ) {
181
+ IR_NODE_LINK_TO ((* conv_bias) , fused_conv_op);
184
182
}
185
183
186
184
CorrectGraphEdges (graph, elementwise_add_out, conv_output);
187
185
GraphSafeRemoveNodes (graph,
188
186
{elementwise_add_out, conv_op, elementwise_add_op});
187
+ (*fusion_stats)++;
188
+ }
189
+
190
+ std::tuple<Node*, Node*, Node*, Node*>
191
+ ResidualConnectionMKLDNNFusePass::GetNodesFromConv (
192
+ const patterns::Conv& conv_pattern,
193
+ const GraphPatternDetector::subgraph_t & subgraph) const {
194
+ GET_IR_NODE_FROM_SUBGRAPH (conv_op, conv_op, conv_pattern);
195
+ GET_IR_NODE_FROM_SUBGRAPH (conv_input, conv_input, conv_pattern);
196
+ GET_IR_NODE_FROM_SUBGRAPH (conv_filter, conv_filter, conv_pattern);
197
+ GET_IR_NODE_FROM_SUBGRAPH (conv_output, conv_output, conv_pattern);
198
+
199
+ return std::make_tuple (conv_op, conv_input, conv_filter, conv_output);
189
200
}
190
201
191
- graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX (
192
- const std::string& name_scope_, graph_ptr graph) const {
202
+ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX (
203
+ const std::string& name_scope,
204
+ const GraphWithStats& graph_with_stats) const {
205
+ ir::Graph* graph;
206
+ int stats;
207
+
208
+ std::tie (graph, stats) = graph_with_stats;
209
+
193
210
GraphPatternDetector gpd;
194
211
auto pattern = gpd.mutable_pattern ();
195
212
196
- patterns::Conv conv_pattern{pattern, name_scope_ };
213
+ patterns::Conv conv_pattern{pattern, name_scope };
197
214
auto conv_output = conv_pattern ();
198
215
199
- patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_ };
216
+ patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope };
200
217
elementwise_add_pattern (
201
218
conv_output,
202
219
pattern->NewNode (elementwise_add_pattern.elementwise_add_y_repr ()));
203
220
conv_output->AsIntermediate ();
204
221
205
- auto get_node_from_conv =
206
- [&conv_pattern](const GraphPatternDetector::subgraph_t & subgraph)
207
- -> std::tuple<Node*, Node*, Node*, Node*> {
208
- GET_IR_NODE_FROM_SUBGRAPH (conv_op, conv_op, conv_pattern);
209
- GET_IR_NODE_FROM_SUBGRAPH (conv_input, conv_input, conv_pattern);
210
- GET_IR_NODE_FROM_SUBGRAPH (conv_filter, conv_filter, conv_pattern);
211
- GET_IR_NODE_FROM_SUBGRAPH (conv_output, conv_output, conv_pattern);
212
-
213
- return std::make_tuple (conv_op, conv_input, conv_filter, conv_output);
214
- };
215
-
216
222
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
217
223
const GraphPatternDetector::subgraph_t & subgraph)
218
224
-> std::tuple<Node*, Node*, Node*> {
@@ -227,43 +233,29 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
227
233
elementwise_add_out);
228
234
};
229
235
230
- auto can_fuse = [this ](Node* op1, Node* op2) -> bool {
231
- return this ->FindFuseOption (*op1, *op2) == FUSE_MKLDNN;
232
- };
233
-
234
- auto fuse_handler =
235
- FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
236
-
237
- gpd (graph.get (), fuse_handler);
238
-
239
- return graph;
236
+ return ExecuteHandlerOnGraph (
237
+ &gpd, graph_with_stats,
238
+ [this , &conv_pattern](const GraphPatternDetector::subgraph_t & subgraph) {
239
+ return GetNodesFromConv (conv_pattern, subgraph);
240
+ },
241
+ get_node_from_elementwise_add);
240
242
}
241
243
242
- graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY (
243
- const std::string& name_scope_, graph_ptr graph) const {
244
+ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY (
245
+ const std::string& name_scope,
246
+ const GraphWithStats& graph_with_stats) const {
244
247
GraphPatternDetector gpd;
245
248
auto pattern = gpd.mutable_pattern ();
246
249
247
- patterns::Conv conv_pattern{pattern, name_scope_ };
250
+ patterns::Conv conv_pattern{pattern, name_scope };
248
251
auto conv_output = conv_pattern ();
249
252
250
- patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_ };
253
+ patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope };
251
254
elementwise_add_pattern (
252
255
pattern->NewNode (elementwise_add_pattern.elementwise_add_x_repr ()),
253
256
conv_output);
254
257
conv_output->AsIntermediate ();
255
258
256
- auto get_node_from_conv =
257
- [&conv_pattern](const GraphPatternDetector::subgraph_t & subgraph)
258
- -> std::tuple<Node*, Node*, Node*, Node*> {
259
- GET_IR_NODE_FROM_SUBGRAPH (conv_op, conv_op, conv_pattern);
260
- GET_IR_NODE_FROM_SUBGRAPH (conv_input, conv_input, conv_pattern);
261
- GET_IR_NODE_FROM_SUBGRAPH (conv_filter, conv_filter, conv_pattern);
262
- GET_IR_NODE_FROM_SUBGRAPH (conv_output, conv_output, conv_pattern);
263
-
264
- return std::make_tuple (conv_op, conv_input, conv_filter, conv_output);
265
- };
266
-
267
259
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
268
260
const GraphPatternDetector::subgraph_t & subgraph)
269
261
-> std::tuple<Node*, Node*, Node*> {
@@ -278,22 +270,45 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
278
270
elementwise_add_out);
279
271
};
280
272
273
+ return ExecuteHandlerOnGraph (
274
+ &gpd, graph_with_stats,
275
+ [this , &conv_pattern](const GraphPatternDetector::subgraph_t & subgraph) {
276
+ return GetNodesFromConv (conv_pattern, subgraph);
277
+ },
278
+ get_node_from_elementwise_add);
279
+ }
280
+
281
+ GraphWithStats ResidualConnectionMKLDNNFusePass::ExecuteHandlerOnGraph (
282
+ GraphPatternDetector* gpd, const GraphWithStats& graph_with_stats,
283
+ const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv,
284
+ const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
285
+ get_node_from_elementwise_add) const {
286
+ ir::Graph* graph;
287
+ int stats;
288
+
289
+ std::tie (graph, stats) = graph_with_stats;
290
+
281
291
auto can_fuse = [this ](Node* op1, Node* op2) -> bool {
282
292
return this ->FindFuseOption (*op1, *op2) == FUSE_MKLDNN;
283
293
};
284
294
285
295
auto fuse_handler =
286
296
FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
287
297
288
- gpd (graph. get () , fuse_handler);
298
+ (* gpd) (graph, fuse_handler);
289
299
290
- return graph;
300
+ return std::make_pair ( graph, stats + fuse_handler. get_stats ()) ;
291
301
}
292
302
293
303
// Entry point of the pass: initialises the pass base for this graph, then
// runs FuseConvAsX followed by FuseConvAsY on it.
//
// Both stages receive the graph as a raw pointer paired with a running
// count, starting at 0. The final count is handed to AddStatis; nothing is
// written to stdout. The graph passed in is returned to the caller.
graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
  FusePassBase::Init(name_scope_, graph.get());

  // graph.get() is only borrowed by the fusion stages; `graph` keeps the
  // object and is returned below.
  const auto fused_graph_with_stats = FuseConvAsY(
      name_scope_, FuseConvAsX(name_scope_, std::make_pair(graph.get(), 0)));

  AddStatis(fused_graph_with_stats.second);
  return graph;
}
298
313
} // namespace ir
299
314
} // namespace framework
0 commit comments