@@ -99,10 +99,9 @@ bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
99
99
return false ;
100
100
}
101
101
102
- std::pair<bool , Node*> ResidualConnectionMKLDNNFusePass::HasBias (
103
- const Node& op) const {
102
+ std::pair<bool , Node*> HasBias (const Node& op, const std::string& bias_name) {
104
103
auto bias_input_names = op.Op ()->Inputs ();
105
- auto bias_it = bias_input_names.find (" Bias " );
104
+ auto bias_it = bias_input_names.find (bias_name );
106
105
107
106
if (bias_it != std::end (bias_input_names)) {
108
107
bool has_bias = !bias_it->second .empty ();
@@ -121,6 +120,74 @@ std::pair<bool, Node*> ResidualConnectionMKLDNNFusePass::HasBias(
121
120
return std::make_pair (false , nullptr );
122
121
}
123
122
123
+ ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler (
124
+ const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv_op,
125
+ const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
126
+ get_node_from_elementwise_add_op,
127
+ const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func)
128
+ : get_node_from_conv_op{get_node_from_conv_op},
129
+ get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
130
+ can_fuse_func{can_fuse_func} {}
131
+
132
+ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator ()(
133
+ const GraphPatternDetector::subgraph_t & subgraph, Graph* graph) {
134
+ Node* conv_op;
135
+ Node* conv_input;
136
+ Node* conv_filter;
137
+ Node* conv_output;
138
+
139
+ Node* elementwise_add_op;
140
+ Node* elementwise_add_identity;
141
+ Node* elementwise_add_out;
142
+
143
+ std::tie (conv_op, conv_input, conv_filter, conv_output) =
144
+ get_node_from_conv_op (subgraph);
145
+ std::tie (elementwise_add_op, elementwise_add_identity, elementwise_add_out) =
146
+ get_node_from_elementwise_add_op (subgraph);
147
+
148
+ if (!can_fuse_func (conv_op, elementwise_add_op)) return ;
149
+
150
+ if (!IsReachable (graph, elementwise_add_identity, conv_output)) return ;
151
+
152
+ OpDesc op_desc;
153
+ op_desc.SetType (" conv2d" );
154
+
155
+ op_desc.SetInput (" Input" , {conv_input->Name ()});
156
+ op_desc.SetInput (" Filter" , {conv_filter->Name ()});
157
+ op_desc.SetInput (" ResidualData" , {elementwise_add_identity->Name ()});
158
+ op_desc.SetOutput (" Output" , {conv_output->Name ()});
159
+
160
+ bool has_bias;
161
+ Node* conv_bias;
162
+
163
+ std::tie (has_bias, conv_bias) = HasBias (*conv_op, " Bias" );
164
+
165
+ if (has_bias) {
166
+ op_desc.SetInput (" Bias" , {conv_bias->Name ()});
167
+ }
168
+
169
+ for (const auto & attr : conv_op->Op ()->GetAttrMap ()) {
170
+ op_desc.SetAttr (attr.first , attr.second );
171
+ }
172
+
173
+ op_desc.SetAttr (" fuse_residual_connection" , true );
174
+
175
+ auto fused_conv_op = graph->CreateOpNode (&op_desc);
176
+
177
+ IR_NODE_LINK_TO (conv_input, fused_conv_op);
178
+ IR_NODE_LINK_TO (conv_filter, fused_conv_op);
179
+ IR_NODE_LINK_TO (elementwise_add_identity, fused_conv_op);
180
+ IR_NODE_LINK_TO (fused_conv_op, conv_output);
181
+
182
+ if (has_bias) {
183
+ IR_NODE_LINK_TO (conv_bias, fused_conv_op);
184
+ }
185
+
186
+ CorrectGraphEdges (graph, elementwise_add_out, conv_output);
187
+ GraphSafeRemoveNodes (graph,
188
+ {elementwise_add_out, conv_op, elementwise_add_op});
189
+ }
190
+
124
191
graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX (
125
192
const std::string& name_scope_, graph_ptr graph) const {
126
193
GraphPatternDetector gpd;
@@ -135,8 +202,8 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
135
202
pattern->NewNode (elementwise_add_pattern.elementwise_add_y_repr ()));
136
203
conv_output->AsIntermediate ();
137
204
138
- auto get_node_from_conv = []( const patterns::Conv& conv_pattern,
139
- const GraphPatternDetector::subgraph_t & subgraph)
205
+ auto get_node_from_conv =
206
+ [&conv_pattern]( const GraphPatternDetector::subgraph_t & subgraph)
140
207
-> std::tuple<Node*, Node*, Node*, Node*> {
141
208
GET_IR_NODE_FROM_SUBGRAPH (conv_op, conv_op, conv_pattern);
142
209
GET_IR_NODE_FROM_SUBGRAPH (conv_input, conv_input, conv_pattern);
@@ -146,8 +213,7 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
146
213
return std::make_tuple (conv_op, conv_input, conv_filter, conv_output);
147
214
};
148
215
149
- auto get_node_from_elementwise_add = [](
150
- const patterns::ElementwiseAdd& elementwise_add_pattern,
216
+ auto get_node_from_elementwise_add = [&elementwise_add_pattern](
151
217
const GraphPatternDetector::subgraph_t & subgraph)
152
218
-> std::tuple<Node*, Node*, Node*> {
153
219
GET_IR_NODE_FROM_SUBGRAPH (elementwise_add_op, elementwise_add_op,
@@ -161,10 +227,14 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
161
227
elementwise_add_out);
162
228
};
163
229
164
- auto handler =
165
- GenerateFuseHandler (conv_pattern, elementwise_add_pattern,
166
- get_node_from_conv, get_node_from_elementwise_add);
167
- gpd (graph.get (), handler);
230
+ auto can_fuse = [this ](Node* op1, Node* op2) -> bool {
231
+ return this ->FindFuseOption (*op1, *op2) == FUSE_MKLDNN;
232
+ };
233
+
234
+ auto fuse_handler =
235
+ FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
236
+
237
+ gpd (graph.get (), fuse_handler);
168
238
169
239
return graph;
170
240
}
@@ -183,8 +253,8 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
183
253
conv_output);
184
254
conv_output->AsIntermediate ();
185
255
186
- auto get_node_from_conv = []( const patterns::Conv& conv_pattern,
187
- const GraphPatternDetector::subgraph_t & subgraph)
256
+ auto get_node_from_conv =
257
+ [&conv_pattern]( const GraphPatternDetector::subgraph_t & subgraph)
188
258
-> std::tuple<Node*, Node*, Node*, Node*> {
189
259
GET_IR_NODE_FROM_SUBGRAPH (conv_op, conv_op, conv_pattern);
190
260
GET_IR_NODE_FROM_SUBGRAPH (conv_input, conv_input, conv_pattern);
@@ -194,8 +264,7 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
194
264
return std::make_tuple (conv_op, conv_input, conv_filter, conv_output);
195
265
};
196
266
197
- auto get_node_from_elementwise_add = [](
198
- const patterns::ElementwiseAdd& elementwise_add_pattern,
267
+ auto get_node_from_elementwise_add = [&elementwise_add_pattern](
199
268
const GraphPatternDetector::subgraph_t & subgraph)
200
269
-> std::tuple<Node*, Node*, Node*> {
201
270
GET_IR_NODE_FROM_SUBGRAPH (elementwise_add_op, elementwise_add_op,
@@ -209,10 +278,14 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
209
278
elementwise_add_out);
210
279
};
211
280
212
- auto handler =
213
- GenerateFuseHandler (conv_pattern, elementwise_add_pattern,
214
- get_node_from_conv, get_node_from_elementwise_add);
215
- gpd (graph.get (), handler);
281
+ auto can_fuse = [this ](Node* op1, Node* op2) -> bool {
282
+ return this ->FindFuseOption (*op1, *op2) == FUSE_MKLDNN;
283
+ };
284
+
285
+ auto fuse_handler =
286
+ FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
287
+
288
+ gpd (graph.get (), fuse_handler);
216
289
217
290
return graph;
218
291
}
0 commit comments