Skip to content

Commit 7423748

Browse files
author
Tomasz Patejko
committed
MKLDNN residual connections fuse pass:
* implements a reachability check between the identity node and the non-identity argument of elementwise_add
* implements handling of the identity node as either the x or the y argument of elementwise_add
1 parent 1722678 commit 7423748

File tree

4 files changed

+245
-83
lines changed

4 files changed

+245
-83
lines changed

paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc

Lines changed: 147 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414

1515
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
1616
#include <functional>
17-
#include <utility>
17+
#include <list>
18+
#include <map>
19+
#include <tuple>
1820

1921
#include "paddle/fluid/framework/ir/graph_traits.h"
2022

2123
namespace paddle {
2224
namespace framework {
2325
namespace ir {
24-
namespace {
2526

2627
// The function keeps the graph consistent by replacing
2728
// a node 'from' in the set of inputs nodes
@@ -51,104 +52,179 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
5152
}
5253
}
5354
}
54-
} // namespace
55-
using graph_ptr = std::unique_ptr<ir::Graph>;
5655

57-
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
58-
FusePassBase::Init(name_scope_, graph.get());
56+
// Reports whether node 'to' can be reached from node 'from' by following
// 'outputs' edges. Returns true immediately when both pointers are equal.
// The search is breadth-first: nodes are taken from the front of a FIFO list
// and their not-yet-visited outputs are appended to its back.
bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
57+
// Helper: linear scan over graph->Nodes(); yields the matching pointer, or
// nullptr when 'node' is not among the graph's nodes.
auto find_node = [](ir::Graph* graph, const Node* node) -> Node* {
58+
for (auto n : graph->Nodes()) {
59+
if (n == node) {
60+
return n;
61+
}
62+
}
5963

60-
GraphPatternDetector gpd;
61-
auto pattern = gpd.mutable_pattern();
64+
return nullptr;
65+
};
6266

63-
patterns::Conv conv_pattern{pattern, name_scope_};
64-
auto conv_output = conv_pattern();
67+
// A node is considered reachable from itself.
if (from == to) {
68+
return true;
69+
}
6570

66-
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
67-
elementwise_add_pattern(conv_output);
71+
std::map<Node*, bool> visited;
6872

69-
conv_output->AsIntermediate();
73+
// Seed the visited map with 'false' for every node yielded by
// GraphTraits::DFS over the graph.
for (auto& node : GraphTraits::DFS(*graph)) {
74+
visited[&node] = false;
75+
}
7076

71-
auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> {
72-
auto bias_input_names = conv_op.Op()->Inputs();
73-
auto bias_it = bias_input_names.find("Bias");
74-
75-
if (bias_it != std::end(bias_input_names)) {
76-
bool has_bias = !bias_it->second.empty();
77-
78-
if (has_bias) {
79-
auto conv_bias_names = bias_it->second;
80-
auto conv_bias_names_it =
81-
std::find_if(std::begin(conv_op.inputs), std::end(conv_op.inputs),
82-
[&conv_bias_names](Node* n) -> bool {
83-
return n->Name() == conv_bias_names[0];
84-
});
85-
return std::make_pair(has_bias, *conv_bias_names_it);
86-
}
87-
}
77+
visited[from] = true;
8878

89-
return std::make_pair(false, nullptr);
90-
};
79+
// FIFO work list of nodes whose outputs still have to be examined.
std::list<Node*> queue;
80+
queue.push_back(from);
9181

92-
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
93-
Graph* g) {
94-
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
95-
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
96-
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
97-
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
98-
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
99-
elementwise_add_pattern);
100-
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
101-
elementwise_add_pattern);
102-
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
103-
elementwise_add_pattern);
82+
while (!queue.empty()) {
83+
// NOTE(review): find_node rescans graph->Nodes() for every dequeued node;
// it acts purely as a membership check on the pointer.
auto cur = find_node(graph, queue.front());
84+
queue.pop_front();
10485

105-
if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;
86+
// A dequeued node that is not part of the graph ends the whole search with
// "not reachable".
if (!cur) return false;
10687

107-
OpDesc op_desc;
108-
op_desc.SetType("conv2d");
88+
for (auto n : cur->outputs) {
89+
if (n == to) {
90+
return true;
91+
}
10992

110-
op_desc.SetInput("Input", {conv_input->Name()});
111-
op_desc.SetInput("Filter", {conv_filter->Name()});
112-
op_desc.SetInput("ResidualData", {elementwise_add_x->Name()});
113-
op_desc.SetOutput("Output", {conv_output->Name()});
93+
// std::map::operator[] default-inserts 'false' for a node that was not
// seeded above, so such a node is treated as unvisited.
if (!visited[n]) {
94+
visited[n] = true;
95+
queue.push_back(n);
96+
}
97+
}
98+
}
99+
// Work list exhausted without encountering 'to'.
return false;
100+
}
114101

115-
bool has_bias;
116-
Node* conv_bias;
102+
// Looks up the "Bias" entry among the inputs of op's OpDesc.
// Returns {true, node}, where node is the element of op.inputs whose Name()
// equals the first listed bias name, when the entry exists and is non-empty;
// returns {false, nullptr} otherwise.
std::pair<bool, Node*> ResidualConnectionMKLDNNFusePass::HasBias(
103+
const Node& op) const {
104+
// NOTE(review): 'auto' creates a local copy here if Inputs() returns a
// reference — confirm against the OpDesc API.
auto bias_input_names = op.Op()->Inputs();
105+
auto bias_it = bias_input_names.find("Bias");
117106

118-
std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op);
107+
if (bias_it != std::end(bias_input_names)) {
108+
bool has_bias = !bias_it->second.empty();
119109

120110
if (has_bias) {
121-
op_desc.SetInput("Bias", {conv_bias->Name()});
111+
auto bias_names = bias_it->second;
112+
// Find the input node whose name matches the first bias variable name.
auto bias_names_it =
113+
std::find_if(std::begin(op.inputs), std::end(op.inputs),
114+
[&bias_names](Node* n) -> bool {
115+
return n->Name() == bias_names[0];
116+
});
117+
// NOTE(review): bias_names_it is dereferenced without being compared to
// std::end(op.inputs); this relies on a matching input node always existing.
return std::make_pair(has_bias, *bias_names_it);
122118
}
119+
}
123120

124-
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
125-
op_desc.SetAttr(attr.first, attr.second);
126-
}
121+
return std::make_pair(false, nullptr);
122+
}
127123

128-
op_desc.SetAttr("fuse_residual_connection", true);
124+
// Builds a Conv + ElementwiseAdd pattern in which the conv output is passed as
// the first argument of the ElementwiseAdd pattern and a freshly created
// pattern node (named by elementwise_add_y_repr()) as the second, then runs
// the pattern detector over the graph with the handler produced by
// GenerateFuseHandler. Takes ownership of 'graph' and returns it.
// NOTE(review): the parameter name_scope_ has the same spelling as the class
// member name_scope_ and shadows it inside this function.
graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
125+
const std::string& name_scope_, graph_ptr graph) const {
126+
GraphPatternDetector gpd;
127+
auto pattern = gpd.mutable_pattern();
129128

130-
auto fused_conv_op = g->CreateOpNode(&op_desc);
129+
patterns::Conv conv_pattern{pattern, name_scope_};
130+
auto conv_output = conv_pattern();
131131

132-
IR_NODE_LINK_TO(conv_input, fused_conv_op);
133-
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
134-
IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op);
135-
IR_NODE_LINK_TO(fused_conv_op, conv_output);
132+
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
133+
elementwise_add_pattern(
134+
conv_output,
135+
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
136+
conv_output->AsIntermediate();
136137

137-
if (has_bias) {
138-
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
139-
}
138+
// Extracts (conv_op, conv_input, conv_filter, conv_output) from a matched
// subgraph.
auto get_node_from_conv = [](const patterns::Conv& conv_pattern,
139+
const GraphPatternDetector::subgraph_t& subgraph)
140+
-> std::tuple<Node*, Node*, Node*, Node*> {
141+
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
142+
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
143+
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
144+
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
145+
146+
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
147+
};
148+
149+
// Extracts (elementwise_add_op, elementwise_add_y, elementwise_add_out) from
// a matched subgraph; elementwise_add_y is the argument that is not the conv
// output in this variant.
auto get_node_from_elementwise_add = [](
150+
const patterns::ElementwiseAdd& elementwise_add_pattern,
151+
const GraphPatternDetector::subgraph_t& subgraph)
152+
-> std::tuple<Node*, Node*, Node*> {
153+
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
154+
elementwise_add_pattern);
155+
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
156+
elementwise_add_pattern);
157+
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
158+
elementwise_add_pattern);
159+
160+
return std::make_tuple(elementwise_add_op, elementwise_add_y,
161+
elementwise_add_out);
162+
};
163+
164+
auto handler =
165+
GenerateFuseHandler(conv_pattern, elementwise_add_pattern,
166+
get_node_from_conv, get_node_from_elementwise_add);
167+
// Run the detector; 'handler' is invoked with each matched subgraph.
gpd(graph.get(), handler);
140168

141-
CorrectGraphEdges(g, elementwise_add_out, conv_output);
142-
GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
143-
};
169+
return graph;
170+
}
171+
172+
// Counterpart of FuseConvAsX: builds a Conv + ElementwiseAdd pattern in which
// a freshly created pattern node (named by elementwise_add_x_repr()) is the
// first argument of the ElementwiseAdd pattern and the conv output is the
// second, then runs the pattern detector over the graph with the handler
// produced by GenerateFuseHandler. Takes ownership of 'graph' and returns it.
graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
    const std::string& name_scope_, graph_ptr graph) const {
  using subgraph_t = GraphPatternDetector::subgraph_t;
  using conv_nodes_t = std::tuple<Node*, Node*, Node*, Node*>;
  using add_nodes_t = std::tuple<Node*, Node*, Node*>;

  GraphPatternDetector detector;
  auto pattern_root = detector.mutable_pattern();

  patterns::Conv conv_pattern{pattern_root, name_scope_};
  auto conv_output = conv_pattern();

  // The non-conv operand of the addition is matched by a fresh pattern node.
  patterns::ElementwiseAdd elementwise_add_pattern{pattern_root, name_scope_};
  auto identity_x =
      pattern_root->NewNode(elementwise_add_pattern.elementwise_add_x_repr());
  elementwise_add_pattern(identity_x, conv_output);
  conv_output->AsIntermediate();

  // Extracts (conv_op, conv_input, conv_filter, conv_output) from a match.
  auto extract_conv_nodes = [](const patterns::Conv& conv_pattern,
                               const subgraph_t& subgraph) -> conv_nodes_t {
    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);

    return conv_nodes_t(conv_op, conv_input, conv_filter, conv_output);
  };

  // Extracts (elementwise_add_op, elementwise_add_x, elementwise_add_out)
  // from a match; elementwise_add_x is the operand that is not the conv
  // output in this variant.
  auto extract_add_nodes = [](
      const patterns::ElementwiseAdd& elementwise_add_pattern,
      const subgraph_t& subgraph) -> add_nodes_t {
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
                              elementwise_add_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
                              elementwise_add_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
                              elementwise_add_pattern);

    return add_nodes_t(elementwise_add_op, elementwise_add_x,
                       elementwise_add_out);
  };

  auto fuse_handler =
      GenerateFuseHandler(conv_pattern, elementwise_add_pattern,
                          extract_conv_nodes, extract_add_nodes);
  // Run the detector; 'fuse_handler' is invoked with each matched subgraph.
  detector(graph.get(), fuse_handler);

  return graph;
}
219+
220+
// Entry point of the pass: initialises the FusePassBase state for this graph,
// then applies the two fusion variants one after the other — first
// FuseConvAsX, then FuseConvAsY on its result — and returns the final graph.
graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
  FusePassBase::Init(name_scope_, graph.get());

  graph_ptr fused_as_x = FuseConvAsX(name_scope_, std::move(graph));
  return FuseConvAsY(name_scope_, std::move(fused_as_x));
}
149225
} // namespace ir
150226
} // namespace framework
151227
} // namespace paddle
152228

153229
// Registers the pass under the name conv_elementwise_add_mkldnn_fuse_pass.
// The registered name is unchanged by this commit; only the implementing
// class is renamed.
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
154-
paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);
230+
paddle::framework::ir::ResidualConnectionMKLDNNFusePass);

paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include <string>
18+
#include <utility>
1819
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
1920
#include "paddle/fluid/framework/ir/graph.h"
2021
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
@@ -23,16 +24,105 @@ namespace paddle {
2324
namespace framework {
2425
namespace ir {
2526

26-
class ConvElementwiseAddMKLDNNFusePass : public FusePassBase {
27+
// Owning handle to an IR graph, as passed into and returned from the pass.
using graph_ptr = std::unique_ptr<ir::Graph>;
28+
29+
// Graph helpers defined in conv_elementwise_add_mkldnn_fuse_pass.cc. They are
// declared here because the handler template defined in this header calls
// both of them.
void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
30+
bool IsReachable(ir::Graph* graph, Node* from, Node* to);
31+
32+
// Signature of the callback run for a subgraph matched by
// GraphPatternDetector.
using handler_func = std::function<void(
33+
const GraphPatternDetector::subgraph_t& subgraph, Graph* g)>;
34+
35+
// Fuse pass that handles a conv op whose output feeds elementwise_add, with
// the conv output appearing either as the x or as the y argument (one private
// method per variant).
class ResidualConnectionMKLDNNFusePass : public FusePassBase {
36+
private:
37+
// One fusion routine per position of the conv output in elementwise_add.
// NOTE(review): the parameter name_scope_ is spelled like the member
// name_scope_ below and shadows it inside these methods.
graph_ptr FuseConvAsX(const std::string& name_scope_, graph_ptr graph) const;
38+
graph_ptr FuseConvAsY(const std::string& name_scope_, graph_ptr graph) const;
39+
40+
// Returns whether 'op' has a bias input and, if so, the node for it.
std::pair<bool, Node*> HasBias(const Node& op) const;
41+
42+
// Produces the subgraph callback shared by both fusion routines. The two
// functor arguments extract the conv nodes and the elementwise_add nodes
// from a matched subgraph.
template <typename CONV_FUNC, typename ELEMENTWISE_ADD_FUNC,
43+
typename HANDLER_FUNC = handler_func>
44+
HANDLER_FUNC GenerateFuseHandler(
45+
const patterns::Conv& conv_pattern,
46+
const patterns::ElementwiseAdd& elementwise_add_pattern,
47+
CONV_FUNC get_node_from_conv_op,
48+
ELEMENTWISE_ADD_FUNC get_node_from_elementwise_add_op) const;
49+
2750
public:
28-
virtual ~ConvElementwiseAddMKLDNNFusePass() {}
51+
virtual ~ResidualConnectionMKLDNNFusePass() {}
2952

3053
protected:
31-
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
54+
// Runs the pass on 'graph' and returns the resulting graph.
std::unique_ptr<ir::Graph> ApplyImpl(graph_ptr graph) const;
3255

33-
const std::string name_scope_{"residual_connections_fuse_pass"};
56+
// Scope name handed to FusePassBase::Init and to the pattern constructors.
const std::string name_scope_{"residual_connection_fuse_pass"};
3457
};
3558

59+
template <typename CONV_FUNC, typename ELEMENTWISE_ADD_FUNC,
60+
typename HANDLER_FUNC>
61+
HANDLER_FUNC ResidualConnectionMKLDNNFusePass::GenerateFuseHandler(
62+
const patterns::Conv& conv_pattern,
63+
const patterns::ElementwiseAdd& elementwise_add_pattern,
64+
CONV_FUNC get_node_from_conv_op,
65+
ELEMENTWISE_ADD_FUNC get_node_from_elementwise_add_op) const {
66+
return [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
67+
Node* conv_op;
68+
Node* conv_input;
69+
Node* conv_filter;
70+
Node* conv_output;
71+
72+
Node* elementwise_add_op;
73+
Node* elementwise_add_identity;
74+
Node* elementwise_add_out;
75+
76+
std::tie(conv_op, conv_input, conv_filter, conv_output) =
77+
get_node_from_conv_op(conv_pattern, subgraph);
78+
std::tie(elementwise_add_op, elementwise_add_identity,
79+
elementwise_add_out) =
80+
get_node_from_elementwise_add_op(elementwise_add_pattern, subgraph);
81+
82+
if (this->FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN)
83+
return;
84+
85+
if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
86+
87+
OpDesc op_desc;
88+
op_desc.SetType("conv2d");
89+
90+
op_desc.SetInput("Input", {conv_input->Name()});
91+
op_desc.SetInput("Filter", {conv_filter->Name()});
92+
op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()});
93+
op_desc.SetOutput("Output", {conv_output->Name()});
94+
95+
bool has_bias;
96+
Node* conv_bias;
97+
98+
std::tie(has_bias, conv_bias) = this->HasBias(*conv_op);
99+
100+
if (has_bias) {
101+
op_desc.SetInput("Bias", {conv_bias->Name()});
102+
}
103+
104+
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
105+
op_desc.SetAttr(attr.first, attr.second);
106+
}
107+
108+
op_desc.SetAttr("fuse_residual_connection", true);
109+
110+
auto fused_conv_op = graph->CreateOpNode(&op_desc);
111+
112+
IR_NODE_LINK_TO(conv_input, fused_conv_op);
113+
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
114+
IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op);
115+
IR_NODE_LINK_TO(fused_conv_op, conv_output);
116+
117+
if (has_bias) {
118+
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
119+
}
120+
121+
CorrectGraphEdges(graph, elementwise_add_out, conv_output);
122+
GraphSafeRemoveNodes(graph,
123+
{elementwise_add_out, conv_op, elementwise_add_op});
124+
};
125+
}
36126
} // namespace ir
37127
} // namespace framework
38128
} // namespace paddle

0 commit comments

Comments
 (0)