Skip to content

Commit 53da846

Browse files
author
Tomasz Patejko
committed
MKLDNN residual connections fuse pass: initial implementation of fusion for projection pass
test=develop
1 parent dbc4fcd commit 53da846

File tree

2 files changed

+206
-39
lines changed

2 files changed

+206
-39
lines changed

paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc

Lines changed: 147 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -120,17 +120,18 @@ boost::optional<Node*> HasBias(const Node& op, const std::string& bias_name) {
120120
return boost::none;
121121
}
122122

123-
ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler(
124-
const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv_op,
125-
const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
126-
get_node_from_elementwise_add_op,
127-
const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func)
123+
// Constructs the match callback used by the identity (single-conv) fusions.
// Stores the fuse predicate and the two node-extraction callbacks, and starts
// the fusion counter at zero. In the added ('+') version the initializers are
// listed in the same order as the members are declared in the struct.
// NOTE(review): this listing interleaves removed ('-') and added ('+') lines of
// the diff; the '-' initializers belong to the previous FuseHandler constructor.
ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle(
124+
const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func,
125+
const ResidualConnectionMKLDNNFusePass::IdentityConvFunc&
126+
get_node_from_conv_op,
127+
const ResidualConnectionMKLDNNFusePass::IdentityElementwiseAddFunc&
128+
get_node_from_elementwise_add_op)
128129
: fusion_stats{std::make_shared<int>(0)},
130+
can_fuse_func{can_fuse_func},
129131
get_node_from_conv_op{get_node_from_conv_op},
130-
get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
131-
can_fuse_func{can_fuse_func} {}
132+
get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {}
132133

133-
void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
134+
void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
134135
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
135136
Node* conv_op;
136137
Node* conv_input;
@@ -187,6 +188,104 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
187188
(*fusion_stats)++;
188189
}
189190

191+
// Constructs the match callback used by the projection (two-conv) fusion.
//
// can_fuse_func                     predicate called with (conv op, add op);
//                                   fusion is skipped when it returns false.
// get_node_from_conv_x_op           extracts (op, input, filter, output) of the
//                                   first matched conv from a subgraph.
// get_node_from_conv_y_op           same, for the second matched conv.
// get_node_from_elementwise_add_op  extracts (add op, add output).
//
// The fusion counter starts at zero and lives behind a shared_ptr —
// presumably so that copies of the handle made by the pattern detector keep
// updating the same counter; confirm against GraphPatternDetector.
ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::ProjectionFuseHandle(
192+
const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func,
193+
const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc&
194+
get_node_from_conv_x_op,
195+
const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc&
196+
get_node_from_conv_y_op,
197+
const ResidualConnectionMKLDNNFusePass::ProjectionElementwiseAddFunc&
198+
get_node_from_elementwise_add_op)
199+
: fusion_stats{std::make_shared<int>(0)},
200+
can_fuse_func{can_fuse_func},
201+
get_node_from_conv_x_op{get_node_from_conv_x_op},
202+
get_node_from_conv_y_op{get_node_from_conv_y_op},
203+
get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {}
204+
205+
// Handles one matched "projection" subgraph: two conv ops whose outputs are
// summed by an elementwise_add. One conv is replaced by a new conv2d op that
// takes the other conv's output as "ResidualData"; the elementwise_add op and
// its output node are removed.
//
// Steps:
//   1. Pull the conv_x, conv_y and elementwise_add nodes out of the subgraph.
//   2. Return without changes if can_fuse_func rejects either conv/add pair.
//   3. Use IsReachable to decide which conv is rebuilt (the "residual" conv)
//      and which conv's output is fed in as residual data; return without
//      changes if neither reachability test succeeds.
//   4. Build the fused conv2d descriptor, create its node, wire the edges,
//      remove the replaced nodes and bump the fusion counter.
//
// subgraph  the nodes matched by the pattern detector.
// graph     the graph being rewritten in place.
void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
206+
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
207+
Node* conv_x_op;
208+
Node* conv_x_input;
209+
Node* conv_x_filter;
210+
Node* conv_x_output;
211+
212+
Node* conv_y_op;
213+
Node* conv_y_input;
214+
Node* conv_y_filter;
215+
Node* conv_y_output;
216+
217+
Node* elementwise_add_op;
218+
Node* elementwise_add_out;
219+
220+
// All pointers declared above are assigned here before first use.
std::tie(conv_x_op, conv_x_input, conv_x_filter, conv_x_output) =
221+
get_node_from_conv_x_op(subgraph);
222+
std::tie(conv_y_op, conv_y_input, conv_y_filter, conv_y_output) =
223+
get_node_from_conv_y_op(subgraph);
224+
std::tie(elementwise_add_op, elementwise_add_out) =
225+
get_node_from_elementwise_add_op(subgraph);
226+
227+
// Both convs must be accepted by the predicate together with the add op.
if (!can_fuse_func(conv_x_op, elementwise_add_op)) return;
228+
if (!can_fuse_func(conv_y_op, elementwise_add_op)) return;
229+
230+
Node* projection_node;
231+
Node* residual_conv_op;
232+
Node* residual_conv_input;
233+
Node* residual_conv_filter;
234+
Node* residual_conv_output;
235+
236+
// Choose the roles of the two convs. IsReachable is defined outside this
// view; presumably it reports whether a path exists from the second argument
// to the third — TODO confirm. The conv whose output is reachable from the
// other conv's input is the one that gets rebuilt; the other conv's output
// becomes the residual operand.
if (IsReachable(graph, conv_x_input, conv_y_output)) {
237+
projection_node = conv_x_output;
238+
residual_conv_op = conv_y_op;
239+
residual_conv_input = conv_y_input;
240+
residual_conv_filter = conv_y_filter;
241+
residual_conv_output = conv_y_output;
242+
} else if (IsReachable(graph, conv_y_input, conv_x_output)) {
243+
projection_node = conv_y_output;
244+
residual_conv_op = conv_x_op;
245+
residual_conv_input = conv_x_input;
246+
residual_conv_filter = conv_x_filter;
247+
residual_conv_output = conv_x_output;
248+
} else {
249+
return;
250+
}
251+
252+
// Describe the replacement conv2d: same input/filter/output variables as the
// residual conv, plus the projection output as "ResidualData".
OpDesc op_desc;
253+
op_desc.SetType("conv2d");
254+
255+
op_desc.SetInput("Input", {residual_conv_input->Name()});
256+
op_desc.SetInput("Filter", {residual_conv_filter->Name()});
257+
op_desc.SetInput("ResidualData", {projection_node->Name()});
258+
op_desc.SetOutput("Output", {residual_conv_output->Name()});
259+
260+
// Bias is optional; HasBias yields boost::none when there is no such input.
auto residual_conv_bias = HasBias(*residual_conv_op, "Bias");
261+
262+
if (residual_conv_bias) {
263+
op_desc.SetInput("Bias", {(*residual_conv_bias)->Name()});
264+
}
265+
266+
// Carry over every attribute of the conv being replaced.
for (const auto& attr : residual_conv_op->Op()->GetAttrMap()) {
267+
op_desc.SetAttr(attr.first, attr.second);
268+
}
269+
270+
// Set after the attribute copy above, so a copied value for the same key
// (if any) does not end up as the final one.
op_desc.SetAttr("fuse_residual_connection", true);
271+
272+
auto fused_conv_op = graph->CreateOpNode(&op_desc);
273+
274+
// Connect the new op node to its input and output variable nodes.
IR_NODE_LINK_TO(residual_conv_input, fused_conv_op);
275+
IR_NODE_LINK_TO(residual_conv_filter, fused_conv_op);
276+
IR_NODE_LINK_TO(projection_node, fused_conv_op);
277+
IR_NODE_LINK_TO(fused_conv_op, residual_conv_output);
278+
279+
if (residual_conv_bias) {
280+
IR_NODE_LINK_TO((*residual_conv_bias), fused_conv_op);
281+
}
282+
283+
// CorrectGraphEdges is defined outside this view; presumably it re-points
// the consumers of elementwise_add_out at residual_conv_output before the
// add output node is removed — TODO confirm.
CorrectGraphEdges(graph, elementwise_add_out, residual_conv_output);
284+
GraphSafeRemoveNodes(
285+
graph, {elementwise_add_out, residual_conv_op, elementwise_add_op});
286+
(*fusion_stats)++;
287+
}
288+
190289
std::tuple<Node*, Node*, Node*, Node*>
191290
ResidualConnectionMKLDNNFusePass::GetNodesFromConv(
192291
const patterns::Conv& conv_pattern,
@@ -233,7 +332,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
233332
elementwise_add_out);
234333
};
235334

236-
return ExecuteHandlerOnGraph(
335+
return ExecuteHandleOnGraph<IdentityFuseHandle>(
237336
&gpd, graph_with_stats,
238337
[this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
239338
return GetNodesFromConv(conv_pattern, subgraph);
@@ -270,41 +369,62 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
270369
elementwise_add_out);
271370
};
272371

273-
return ExecuteHandlerOnGraph(
372+
return ExecuteHandleOnGraph<IdentityFuseHandle>(
274373
&gpd, graph_with_stats,
275374
[this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
276375
return GetNodesFromConv(conv_pattern, subgraph);
277376
},
278377
get_node_from_elementwise_add);
279378
}
280379

281-
GraphWithStats ResidualConnectionMKLDNNFusePass::ExecuteHandlerOnGraph(
282-
GraphPatternDetector* gpd, const GraphWithStats& graph_with_stats,
283-
const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv,
284-
const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
285-
get_node_from_elementwise_add) const {
286-
ir::Graph* graph;
287-
int stats;
380+
// Detects and fuses the projection form of a residual connection.
//
// Builds a pattern made of two Conv patterns (x and y) whose outputs are the
// two operands of an ElementwiseAdd pattern, marks both conv outputs as
// intermediate, and runs ProjectionFuseHandle over the matches through
// ExecuteHandleOnGraph.
//
// name_scope        scope name handed to every sub-pattern.
// graph_with_stats  the graph to rewrite and the fusion count so far.
// Returns the result of ExecuteHandleOnGraph (graph plus updated count).
//
// The lambdas passed below capture the local pattern objects by reference;
// ExecuteHandleOnGraph runs the detector before it returns, so they are only
// invoked while those locals are alive.
//
// NOTE(review): this listing interleaves diff lines. Lines whose marker ends
// in '-' belong to the removed ExecuteHandlerOnGraph definition, not to this
// function.
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
381+
const std::string& name_scope,
382+
const GraphWithStats& graph_with_stats) const {
383+
GraphPatternDetector gpd;
384+
auto pattern = gpd.mutable_pattern();
288385

289-
std::tie(graph, stats) = graph_with_stats;
386+
patterns::Conv conv_x_pattern{pattern, name_scope};
387+
auto conv_x_output = conv_x_pattern();
290388

291-
auto can_fuse = [this](Node* op1, Node* op2) -> bool {
292-
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
293-
};
389+
patterns::Conv conv_y_pattern{pattern, name_scope};
390+
auto conv_y_output = conv_y_pattern();
294391

295-
auto fuse_handler =
296-
FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
392+
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope};
393+
elementwise_add_pattern(conv_x_output, conv_y_output);
394+
conv_x_output->AsIntermediate();
395+
conv_y_output->AsIntermediate();
297396

298-
(*gpd)(graph, fuse_handler);
397+
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
398+
const GraphPatternDetector::subgraph_t& subgraph)
399+
-> std::tuple<Node*, Node*> {
400+
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
401+
elementwise_add_pattern);
402+
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
403+
elementwise_add_pattern);
299404

300-
return std::make_pair(graph, stats + fuse_handler.get_stats());
405+
return std::make_tuple(elementwise_add_op, elementwise_add_out);
406+
};
407+
408+
return ExecuteHandleOnGraph<ProjectionFuseHandle>(
409+
&gpd, graph_with_stats,
410+
[this,
411+
&conv_x_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
412+
return GetNodesFromConv(conv_x_pattern, subgraph);
413+
},
414+
[this,
415+
&conv_y_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
416+
return GetNodesFromConv(conv_y_pattern, subgraph);
417+
},
418+
get_node_from_elementwise_add);
301419
}
302420

303421
graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
304422
FusePassBase::Init(name_scope_, graph.get());
305-
306423
auto fused_graph_with_stats = FuseConvAsY(
307-
name_scope_, FuseConvAsX(name_scope_, std::make_pair(graph.get(), 0)));
424+
name_scope_,
425+
FuseConvAsX(
426+
name_scope_,
427+
FuseProjectionConv(name_scope_, std::make_pair(graph.get(), 0))));
308428

309429
std::cout << "Fused graph " << fused_graph_with_stats.second << std::endl;
310430
AddStatis(fused_graph_with_stats.second);

paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,37 +40,84 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
4040
const GraphWithStats& graph_with_stats) const;
4141
GraphWithStats FuseConvAsY(const std::string& name_scope,
4242
const GraphWithStats& graph_with_stats) const;
43+
// Fuses the projection form of the residual connection: two conv ops whose
// outputs are both operands of one elementwise_add. Takes the graph and the
// running fusion count, and returns them updated.
GraphWithStats FuseProjectionConv(
44+
const std::string& name_scope,
45+
const GraphWithStats& graph_with_stats) const;
4346

4447
template <typename RetType>
4548
using GetNodeFunc =
4649
std::function<RetType(const GraphPatternDetector::subgraph_t& subgraph)>;
47-
using ConvFunc = GetNodeFunc<std::tuple<Node*, Node*, Node*, Node*>>;
48-
using ElementwiseAddFunc = GetNodeFunc<std::tuple<Node*, Node*, Node*>>;
50+
// Callback types that pull specific nodes out of a matched subgraph.
//
// Conv getters yield four nodes; the projection handler unpacks them as
// (conv op, input, filter, output).
using IdentityConvFunc = GetNodeFunc<std::tuple<Node*, Node*, Node*, Node*>>;
51+
// Identity elementwise_add getter: three nodes. NOTE(review): the meaning of
// each slot is not visible in this view — confirm against FuseConvAsX/Y.
using IdentityElementwiseAddFunc =
52+
GetNodeFunc<std::tuple<Node*, Node*, Node*>>;
53+
54+
// The projection fusion reuses the conv getter shape for both of its convs.
using ProjectionConvFunc = IdentityConvFunc;
55+
// Projection elementwise_add getter: (add op, add output).
using ProjectionElementwiseAddFunc = GetNodeFunc<std::tuple<Node*, Node*>>;
56+
4957
using CanFuseFunc = std::function<bool(Node*, Node*)>;
5058

5159
std::tuple<Node*, Node*, Node*, Node*> GetNodesFromConv(
5260
const patterns::Conv& conv_pattern,
5361
const GraphPatternDetector::subgraph_t& subgraph) const;
5462

55-
GraphWithStats ExecuteHandlerOnGraph(
56-
GraphPatternDetector* gpd, const GraphWithStats& graph_with_stats,
57-
const ConvFunc& get_node_from_conv,
58-
const ElementwiseAddFunc& get_node_from_elementwise_add) const;
63+
// Declared with the same signature as GetNodesFromConv.
// NOTE(review): no definition or caller of this member appears in the .cc
// changes shown alongside this header (FuseProjectionConv calls
// GetNodesFromConv instead) — confirm it is defined elsewhere or drop it.
std::tuple<Node*, Node*, Node*, Node*> GetNodesFromProjectionConv(
64+
const patterns::Conv& conv_pattern,
65+
const GraphPatternDetector::subgraph_t& subgraph) const;
66+
67+
// Runs one fusion handle type over every match found by a pattern detector.
//
// HandleType        callable handle; constructed from the fuse predicate
//                   followed by the forwarded op_funcs, and must expose
//                   get_stats().
// OpFuncs           node-extraction callbacks expected by HandleType's
//                   constructor, in that constructor's order.
// gpd               detector whose pattern has already been populated.
// graph_with_stats  graph to rewrite and the fusion count accumulated so far.
//
// Returns the same graph pointer paired with the previous count plus the
// number of fusions reported by the handle.
template <typename HandleType, typename... OpFuncs>
68+
GraphWithStats ExecuteHandleOnGraph(GraphPatternDetector* gpd,
69+
const GraphWithStats& graph_with_stats,
70+
OpFuncs&&... op_funcs) const {
71+
ir::Graph* graph;
72+
int stats;
73+
74+
std::tie(graph, stats) = graph_with_stats;
75+
76+
// A pair of ops is fusable only when FindFuseOption selects FUSE_MKLDNN.
auto can_fuse = [this](Node* op1, Node* op2) -> bool {
77+
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
78+
};
79+
80+
auto fuse_handle = HandleType{can_fuse, std::forward<OpFuncs>(op_funcs)...};
81+
82+
// Invoke the detector; it calls fuse_handle for the subgraphs it matches.
// NOTE(review): if the detector stores a copy of fuse_handle, get_stats()
// below still sees the updates only because the counter is a shared_ptr —
// confirm against GraphPatternDetector::operator().
(*gpd)(graph, fuse_handle);
83+
84+
return std::make_pair(graph, stats + fuse_handle.get_stats());
85+
}
86+
87+
// Match callback for the single-conv (identity) residual fusions. Holds the
// fuse predicate, the node-extraction callbacks and a fusion counter.
struct IdentityFuseHandle {
88+
// The predicate comes first so that ExecuteHandleOnGraph can construct any
// handle type as HandleType{can_fuse, op_funcs...}.
IdentityFuseHandle(
89+
const CanFuseFunc& can_fuse_func,
90+
const IdentityConvFunc& get_node_from_conv_op,
91+
const IdentityElementwiseAddFunc& get_node_from_elementwise_add_op);
92+
93+
// Called with each matched subgraph and the graph being rewritten.
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
94+
Graph* graph);
95+
// Number of fusions recorded by this handle so far.
int get_stats() const { return *fusion_stats; }
96+
97+
private:
98+
// Counter held by pointer; copies of the handle share the same int.
std::shared_ptr<int> fusion_stats;
99+
CanFuseFunc can_fuse_func;
100+
IdentityConvFunc get_node_from_conv_op;
101+
IdentityElementwiseAddFunc get_node_from_elementwise_add_op;
102+
};
59103

60-
struct FuseHandler {
61-
FuseHandler(const ConvFunc& get_node_from_conv_op,
62-
const ElementwiseAddFunc& get_node_from_elementwise_add_op,
63-
const CanFuseFunc& can_fuse_func);
104+
// Match callback for the two-conv (projection) residual fusion. Holds the
// fuse predicate, one node-extraction callback per conv, one for the
// elementwise_add, and a fusion counter.
// NOTE(review): this listing interleaves diff lines; the two members whose
// marker ends in '-' belong to the removed FuseHandler struct.
struct ProjectionFuseHandle {
105+
// The predicate comes first so that ExecuteHandleOnGraph can construct any
// handle type as HandleType{can_fuse, op_funcs...}.
ProjectionFuseHandle(
106+
const CanFuseFunc& can_fuse_func,
107+
const ProjectionConvFunc& get_node_from_conv_x_op,
108+
const ProjectionConvFunc& get_node_from_conv_y_op,
109+
const ProjectionElementwiseAddFunc& get_node_from_elementwise_add_op);
64110

65111
// Called with each matched subgraph and the graph being rewritten.
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
66112
Graph* graph);
67113
// Number of fusions recorded by this handle so far.
int get_stats() const { return *fusion_stats; }
68114

69115
private:
70116
// Counter held by pointer; copies of the handle share the same int.
std::shared_ptr<int> fusion_stats;
71-
ConvFunc get_node_from_conv_op;
72-
ElementwiseAddFunc get_node_from_elementwise_add_op;
73117
CanFuseFunc can_fuse_func;
118+
ProjectionConvFunc get_node_from_conv_x_op;
119+
ProjectionConvFunc get_node_from_conv_y_op;
120+
ProjectionElementwiseAddFunc get_node_from_elementwise_add_op;
74121
};
75122

76123
public:

0 commit comments

Comments
 (0)