@@ -120,17 +120,18 @@ boost::optional<Node*> HasBias(const Node& op, const std::string& bias_name) {
120
120
return boost::none;
121
121
}
122
122
123
// Diff hunk: the FuseHandler constructor (lines marked '-') is replaced by the
// IdentityFuseHandle constructor (lines marked '+').
// Compared with the removed version, can_fuse_func moves from the last
// parameter to the first, and the callback parameter types become
// IdentityConvFunc / IdentityElementwiseAddFunc.
- ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler (
124
- const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv_op,
125
- const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
126
- get_node_from_elementwise_add_op,
127
- const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func)
123
// Stores the fusability predicate and the two node-extraction callbacks; the
// fusion counter is a freshly allocated shared int starting at 0.
+ ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle (
124
+ const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func,
125
+ const ResidualConnectionMKLDNNFusePass::IdentityConvFunc&
126
+ get_node_from_conv_op,
127
+ const ResidualConnectionMKLDNNFusePass::IdentityElementwiseAddFunc&
128
+ get_node_from_elementwise_add_op)
128
129
: fusion_stats{std::make_shared<int >(0 )},
130
// NOTE(review): can_fuse_func is now initialized right after fusion_stats;
// presumably this mirrors the member declaration order in the header - confirm.
+ can_fuse_func{can_fuse_func},
129
131
get_node_from_conv_op{get_node_from_conv_op},
130
- get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
131
- can_fuse_func{can_fuse_func} {}
132
+ get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {}
132
133
133
- void ResidualConnectionMKLDNNFusePass::FuseHandler ::operator ()(
134
+ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle ::operator ()(
134
135
const GraphPatternDetector::subgraph_t & subgraph, Graph* graph) {
135
136
Node* conv_op;
136
137
Node* conv_input;
@@ -187,6 +188,104 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
187
188
(*fusion_stats)++;
188
189
}
189
190
191
// Constructs the handler used for the two-convolution ("projection") pattern.
//
// can_fuse_func                     predicate later called with (conv op, add op)
// get_node_from_conv_x_op           callback yielding the four nodes of conv X
// get_node_from_conv_y_op           callback yielding the four nodes of conv Y
// get_node_from_elementwise_add_op  callback yielding the add op and its output
//
// All four callables are copied into members; the fusion counter is a freshly
// allocated shared int starting at 0.
// NOTE(review): the counter is presumably shared so that copies of the handler
// update the same count - confirm against the class declaration.
+ ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::ProjectionFuseHandle (
192
+ const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func,
193
+ const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc&
194
+ get_node_from_conv_x_op,
195
+ const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc&
196
+ get_node_from_conv_y_op,
197
+ const ResidualConnectionMKLDNNFusePass::ProjectionElementwiseAddFunc&
198
+ get_node_from_elementwise_add_op)
199
+ : fusion_stats{std::make_shared<int >(0 )},
200
+ can_fuse_func{can_fuse_func},
201
+ get_node_from_conv_x_op{get_node_from_conv_x_op},
202
+ get_node_from_conv_y_op{get_node_from_conv_y_op},
203
+ get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {}
204
+
205
// Handles one matched subgraph in which two convolutions feed an
// elementwise add.
//
// One of the two convolutions is rebuilt as a new conv2d op that takes the
// other convolution's output as its ResidualData input; the elementwise add
// op, its output node and the rebuilt convolution's original op node are then
// removed from the graph. Returns without touching the graph when either
// convolution is rejected by can_fuse_func or when neither IsReachable check
// succeeds. On success the fusion counter is incremented by one.
//
// subgraph  the match, passed unchanged to the three node-extraction callbacks
// graph     the graph that is modified in place
+ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator ()(
206
+ const GraphPatternDetector::subgraph_t & subgraph, Graph* graph) {
207
+ Node* conv_x_op;
208
+ Node* conv_x_input;
209
+ Node* conv_x_filter;
210
+ Node* conv_x_output;
211
+
212
+ Node* conv_y_op;
213
+ Node* conv_y_input;
214
+ Node* conv_y_filter;
215
+ Node* conv_y_output;
216
+
217
+ Node* elementwise_add_op;
218
+ Node* elementwise_add_out;
219
+
220
// The pointers declared above are all assigned here, before any use.
+ std::tie (conv_x_op, conv_x_input, conv_x_filter, conv_x_output) =
221
+ get_node_from_conv_x_op (subgraph);
222
+ std::tie (conv_y_op, conv_y_input, conv_y_filter, conv_y_output) =
223
+ get_node_from_conv_y_op (subgraph);
224
+ std::tie (elementwise_add_op, elementwise_add_out) =
225
+ get_node_from_elementwise_add_op (subgraph);
226
+
227
// Both convolutions must be accepted by the predicate, not only the one that
// ends up being rebuilt.
+ if (!can_fuse_func (conv_x_op, elementwise_add_op)) return ;
228
+ if (!can_fuse_func (conv_y_op, elementwise_add_op)) return ;
229
+
230
// projection_node is the conv output that becomes ResidualData; the
// residual_conv_* pointers describe the convolution that is rebuilt.
// Every path that continues past the if/else chain below assigns all five.
+ Node* projection_node;
231
+ Node* residual_conv_op;
232
+ Node* residual_conv_input;
233
+ Node* residual_conv_filter;
234
+ Node* residual_conv_output;
235
+
236
// NOTE(review): IsReachable is defined outside this view; presumably it
// reports whether a path exists between the two given nodes - confirm the
// argument order before relying on the direction described here.
+ if (IsReachable (graph, conv_x_input, conv_y_output)) {
237
+ projection_node = conv_x_output;
238
+ residual_conv_op = conv_y_op;
239
+ residual_conv_input = conv_y_input;
240
+ residual_conv_filter = conv_y_filter;
241
+ residual_conv_output = conv_y_output;
242
+ } else if (IsReachable (graph, conv_y_input, conv_x_output)) {
243
+ projection_node = conv_y_output;
244
+ residual_conv_op = conv_x_op;
245
+ residual_conv_input = conv_x_input;
246
+ residual_conv_filter = conv_x_filter;
247
+ residual_conv_output = conv_x_output;
248
+ } else {
249
// Neither check succeeded: leave the graph untouched.
+ return ;
250
+ }
251
+
252
// Describe the replacement convolution: same input, filter and output as
// the chosen convolution, plus the other convolution's output as residual.
+ OpDesc op_desc;
253
+ op_desc.SetType (" conv2d" );
254
+
255
+ op_desc.SetInput (" Input" , {residual_conv_input->Name ()});
256
+ op_desc.SetInput (" Filter" , {residual_conv_filter->Name ()});
257
+ op_desc.SetInput (" ResidualData" , {projection_node->Name ()});
258
+ op_desc.SetOutput (" Output" , {residual_conv_output->Name ()});
259
+
260
// HasBias returns an optional node pointer; it is empty when no bias is found.
+ auto residual_conv_bias = HasBias (*residual_conv_op, " Bias" );
261
+
262
+ if (residual_conv_bias) {
263
+ op_desc.SetInput (" Bias" , {(*residual_conv_bias)->Name ()});
264
+ }
265
+
266
// Carry over every attribute of the convolution being replaced.
+ for (const auto & attr : residual_conv_op->Op ()->GetAttrMap ()) {
267
+ op_desc.SetAttr (attr.first , attr.second );
268
+ }
269
+
270
// Set after the copy loop so this value wins over any copied attribute of
// the same name.
+ op_desc.SetAttr (" fuse_residual_connection" , true );
271
+
272
+ auto fused_conv_op = graph->CreateOpNode (&op_desc);
273
+
274
// Wire the new op node to the same variable nodes named in op_desc.
+ IR_NODE_LINK_TO (residual_conv_input, fused_conv_op);
275
+ IR_NODE_LINK_TO (residual_conv_filter, fused_conv_op);
276
+ IR_NODE_LINK_TO (projection_node, fused_conv_op);
277
+ IR_NODE_LINK_TO (fused_conv_op, residual_conv_output);
278
+
279
+ if (residual_conv_bias) {
280
+ IR_NODE_LINK_TO ((*residual_conv_bias), fused_conv_op);
281
+ }
282
+
283
// NOTE(review): CorrectGraphEdges is defined outside this view; presumably it
// re-points consumers of elementwise_add_out to residual_conv_output - confirm.
+ CorrectGraphEdges (graph, elementwise_add_out, residual_conv_output);
284
// Only the add, its output and the rebuilt convolution's old op node are
// removed; the other convolution is not in this list and stays in the graph.
+ GraphSafeRemoveNodes (
285
+ graph, {elementwise_add_out, residual_conv_op, elementwise_add_op});
286
+ (*fusion_stats)++;
287
+ }
288
+
190
289
std::tuple<Node*, Node*, Node*, Node*>
191
290
ResidualConnectionMKLDNNFusePass::GetNodesFromConv (
192
291
const patterns::Conv& conv_pattern,
@@ -233,7 +332,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
233
332
elementwise_add_out);
234
333
};
235
334
236
- return ExecuteHandlerOnGraph (
335
+ return ExecuteHandleOnGraph<IdentityFuseHandle> (
237
336
&gpd, graph_with_stats,
238
337
[this , &conv_pattern](const GraphPatternDetector::subgraph_t & subgraph) {
239
338
return GetNodesFromConv (conv_pattern, subgraph);
@@ -270,41 +369,62 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
270
369
elementwise_add_out);
271
370
};
272
371
273
- return ExecuteHandlerOnGraph (
372
+ return ExecuteHandleOnGraph<IdentityFuseHandle> (
274
373
&gpd, graph_with_stats,
275
374
[this , &conv_pattern](const GraphPatternDetector::subgraph_t & subgraph) {
276
375
return GetNodesFromConv (conv_pattern, subgraph);
277
376
},
278
377
get_node_from_elementwise_add);
279
378
}
280
379
281
// Diff hunk: the non-template ExecuteHandlerOnGraph member (lines marked '-')
// is removed and FuseProjectionConv (lines marked '+') is added in its place.
// The removed function unpacked (graph, stats), built a FuseHandler with a
// predicate comparing FindFuseOption against FUSE_MKLDNN, ran the detector and
// returned the graph paired with stats plus the handler's count.
// Callers in this file now use ExecuteHandleOnGraph<...>, whose definition is
// outside this view.
- GraphWithStats ResidualConnectionMKLDNNFusePass::ExecuteHandlerOnGraph (
282
- GraphPatternDetector* gpd, const GraphWithStats& graph_with_stats,
283
- const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv,
284
- const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
285
- get_node_from_elementwise_add) const {
286
- ir::Graph* graph;
287
- int stats;
380
// Registers a pattern made of two Conv patterns whose outputs are the two
// operands of an ElementwiseAdd pattern, then runs ProjectionFuseHandle on it
// through ExecuteHandleOnGraph.
//
// name_scope        scope string passed to each of the three pattern objects
// graph_with_stats  forwarded unchanged to ExecuteHandleOnGraph
// Returns whatever ExecuteHandleOnGraph<ProjectionFuseHandle> returns.
+ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv (
381
+ const std::string& name_scope,
382
+ const GraphWithStats& graph_with_stats) const {
383
+ GraphPatternDetector gpd;
384
+ auto pattern = gpd.mutable_pattern ();
288
385
289
- std::tie (graph, stats) = graph_with_stats;
386
+ patterns::Conv conv_x_pattern{pattern, name_scope};
387
+ auto conv_x_output = conv_x_pattern ();
290
388
291
- auto can_fuse = [this ](Node* op1, Node* op2) -> bool {
292
- return this ->FindFuseOption (*op1, *op2) == FUSE_MKLDNN;
293
- };
389
+ patterns::Conv conv_y_pattern{pattern, name_scope};
390
+ auto conv_y_output = conv_y_pattern ();
294
391
295
- auto fuse_handler =
296
- FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
392
// Both conv outputs are the operands of the add pattern and are marked
// intermediate.
// NOTE(review): the exact constraint AsIntermediate imposes on matching is
// defined by the pattern detector, outside this view - confirm.
+ patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope};
393
+ elementwise_add_pattern (conv_x_output, conv_y_output);
394
+ conv_x_output->AsIntermediate ();
395
+ conv_y_output->AsIntermediate ();
297
396
298
- (*gpd)(graph, fuse_handler);
397
// Extracts the add op node and its output node from a matched subgraph.
+ auto get_node_from_elementwise_add = [&elementwise_add_pattern](
398
+ const GraphPatternDetector::subgraph_t & subgraph)
399
+ -> std::tuple<Node*, Node*> {
400
+ GET_IR_NODE_FROM_SUBGRAPH (elementwise_add_op, elementwise_add_op,
401
+ elementwise_add_pattern);
402
+ GET_IR_NODE_FROM_SUBGRAPH (elementwise_add_out, elementwise_add_out,
403
+ elementwise_add_pattern);
299
404
300
- return std::make_pair (graph, stats + fuse_handler.get_stats ());
405
+ return std::make_tuple (elementwise_add_op, elementwise_add_out);
406
+ };
407
+
408
// NOTE(review): the callbacks below capture the local pattern objects by
// reference, so they must not be retained after this function returns;
// presumably ExecuteHandleOnGraph uses them only during the call - confirm.
+ return ExecuteHandleOnGraph<ProjectionFuseHandle>(
409
+ &gpd, graph_with_stats,
410
+ [this ,
411
+ &conv_x_pattern](const GraphPatternDetector::subgraph_t & subgraph) {
412
+ return GetNodesFromConv (conv_x_pattern, subgraph);
413
+ },
414
+ [this ,
415
+ &conv_y_pattern](const GraphPatternDetector::subgraph_t & subgraph) {
416
+ return GetNodesFromConv (conv_y_pattern, subgraph);
417
+ },
418
+ get_node_from_elementwise_add);
301
419
}
302
420
303
421
graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl (graph_ptr graph) const {
304
422
FusePassBase::Init (name_scope_, graph.get ());
305
-
306
423
auto fused_graph_with_stats = FuseConvAsY (
307
- name_scope_, FuseConvAsX (name_scope_, std::make_pair (graph.get (), 0 )));
424
+ name_scope_,
425
+ FuseConvAsX (
426
+ name_scope_,
427
+ FuseProjectionConv (name_scope_, std::make_pair (graph.get (), 0 ))));
308
428
309
429
std::cout << " Fused graph " << fused_graph_with_stats.second << std::endl;
310
430
AddStatis (fused_graph_with_stats.second );
0 commit comments