Skip to content

Commit ee6f778

Browse files
author
Tomasz Patejko
committed
MKLDNN residual connections fuse pass: further refactoring
1 parent 7423748 commit ee6f778

File tree

2 files changed

+112
-98
lines changed

2 files changed

+112
-98
lines changed

paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc

Lines changed: 92 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,9 @@ bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
9999
return false;
100100
}
101101

102-
std::pair<bool, Node*> ResidualConnectionMKLDNNFusePass::HasBias(
103-
const Node& op) const {
102+
std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name) {
104103
auto bias_input_names = op.Op()->Inputs();
105-
auto bias_it = bias_input_names.find("Bias");
104+
auto bias_it = bias_input_names.find(bias_name);
106105

107106
if (bias_it != std::end(bias_input_names)) {
108107
bool has_bias = !bias_it->second.empty();
@@ -121,6 +120,74 @@ std::pair<bool, Node*> ResidualConnectionMKLDNNFusePass::HasBias(
121120
return std::make_pair(false, nullptr);
122121
}
123122

123+
ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler(
124+
const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv_op,
125+
const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
126+
get_node_from_elementwise_add_op,
127+
const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func)
128+
: get_node_from_conv_op{get_node_from_conv_op},
129+
get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
130+
can_fuse_func{can_fuse_func} {}
131+
132+
void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
133+
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
134+
Node* conv_op;
135+
Node* conv_input;
136+
Node* conv_filter;
137+
Node* conv_output;
138+
139+
Node* elementwise_add_op;
140+
Node* elementwise_add_identity;
141+
Node* elementwise_add_out;
142+
143+
std::tie(conv_op, conv_input, conv_filter, conv_output) =
144+
get_node_from_conv_op(subgraph);
145+
std::tie(elementwise_add_op, elementwise_add_identity, elementwise_add_out) =
146+
get_node_from_elementwise_add_op(subgraph);
147+
148+
if (!can_fuse_func(conv_op, elementwise_add_op)) return;
149+
150+
if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
151+
152+
OpDesc op_desc;
153+
op_desc.SetType("conv2d");
154+
155+
op_desc.SetInput("Input", {conv_input->Name()});
156+
op_desc.SetInput("Filter", {conv_filter->Name()});
157+
op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()});
158+
op_desc.SetOutput("Output", {conv_output->Name()});
159+
160+
bool has_bias;
161+
Node* conv_bias;
162+
163+
std::tie(has_bias, conv_bias) = HasBias(*conv_op, "Bias");
164+
165+
if (has_bias) {
166+
op_desc.SetInput("Bias", {conv_bias->Name()});
167+
}
168+
169+
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
170+
op_desc.SetAttr(attr.first, attr.second);
171+
}
172+
173+
op_desc.SetAttr("fuse_residual_connection", true);
174+
175+
auto fused_conv_op = graph->CreateOpNode(&op_desc);
176+
177+
IR_NODE_LINK_TO(conv_input, fused_conv_op);
178+
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
179+
IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op);
180+
IR_NODE_LINK_TO(fused_conv_op, conv_output);
181+
182+
if (has_bias) {
183+
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
184+
}
185+
186+
CorrectGraphEdges(graph, elementwise_add_out, conv_output);
187+
GraphSafeRemoveNodes(graph,
188+
{elementwise_add_out, conv_op, elementwise_add_op});
189+
}
190+
124191
graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
125192
const std::string& name_scope_, graph_ptr graph) const {
126193
GraphPatternDetector gpd;
@@ -135,8 +202,8 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
135202
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
136203
conv_output->AsIntermediate();
137204

138-
auto get_node_from_conv = [](const patterns::Conv& conv_pattern,
139-
const GraphPatternDetector::subgraph_t& subgraph)
205+
auto get_node_from_conv =
206+
[&conv_pattern](const GraphPatternDetector::subgraph_t& subgraph)
140207
-> std::tuple<Node*, Node*, Node*, Node*> {
141208
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
142209
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
@@ -146,8 +213,7 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
146213
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
147214
};
148215

149-
auto get_node_from_elementwise_add = [](
150-
const patterns::ElementwiseAdd& elementwise_add_pattern,
216+
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
151217
const GraphPatternDetector::subgraph_t& subgraph)
152218
-> std::tuple<Node*, Node*, Node*> {
153219
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
@@ -161,10 +227,14 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
161227
elementwise_add_out);
162228
};
163229

164-
auto handler =
165-
GenerateFuseHandler(conv_pattern, elementwise_add_pattern,
166-
get_node_from_conv, get_node_from_elementwise_add);
167-
gpd(graph.get(), handler);
230+
auto can_fuse = [this](Node* op1, Node* op2) -> bool {
231+
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
232+
};
233+
234+
auto fuse_handler =
235+
FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
236+
237+
gpd(graph.get(), fuse_handler);
168238

169239
return graph;
170240
}
@@ -183,8 +253,8 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
183253
conv_output);
184254
conv_output->AsIntermediate();
185255

186-
auto get_node_from_conv = [](const patterns::Conv& conv_pattern,
187-
const GraphPatternDetector::subgraph_t& subgraph)
256+
auto get_node_from_conv =
257+
[&conv_pattern](const GraphPatternDetector::subgraph_t& subgraph)
188258
-> std::tuple<Node*, Node*, Node*, Node*> {
189259
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
190260
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
@@ -194,8 +264,7 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
194264
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
195265
};
196266

197-
auto get_node_from_elementwise_add = [](
198-
const patterns::ElementwiseAdd& elementwise_add_pattern,
267+
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
199268
const GraphPatternDetector::subgraph_t& subgraph)
200269
-> std::tuple<Node*, Node*, Node*> {
201270
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
@@ -209,10 +278,14 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
209278
elementwise_add_out);
210279
};
211280

212-
auto handler =
213-
GenerateFuseHandler(conv_pattern, elementwise_add_pattern,
214-
get_node_from_conv, get_node_from_elementwise_add);
215-
gpd(graph.get(), handler);
281+
auto can_fuse = [this](Node* op1, Node* op2) -> bool {
282+
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
283+
};
284+
285+
auto fuse_handler =
286+
FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
287+
288+
gpd(graph.get(), fuse_handler);
216289

217290
return graph;
218291
}

paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h

Lines changed: 20 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include <string>
18+
#include <tuple>
1819
#include <utility>
1920
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
2021
#include "paddle/fluid/framework/ir/graph.h"
@@ -28,24 +29,32 @@ using graph_ptr = std::unique_ptr<ir::Graph>;
2829

2930
void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
3031
bool IsReachable(ir::Graph* graph, Node* from, Node* to);
31-
32-
using handler_func = std::function<void(
33-
const GraphPatternDetector::subgraph_t& subgraph, Graph* g)>;
32+
std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name);
3433

3534
class ResidualConnectionMKLDNNFusePass : public FusePassBase {
3635
private:
3736
graph_ptr FuseConvAsX(const std::string& name_scope_, graph_ptr graph) const;
3837
graph_ptr FuseConvAsY(const std::string& name_scope_, graph_ptr graph) const;
3938

40-
std::pair<bool, Node*> HasBias(const Node& op) const;
39+
template <typename RetType>
40+
using GetNodeFunc =
41+
std::function<RetType(const GraphPatternDetector::subgraph_t& subgraph)>;
42+
using ConvFunc = GetNodeFunc<std::tuple<Node*, Node*, Node*, Node*>>;
43+
using ElementwiseAddFunc = GetNodeFunc<std::tuple<Node*, Node*, Node*>>;
44+
using CanFuseFunc = std::function<bool(Node*, Node*)>;
45+
46+
struct FuseHandler {
47+
FuseHandler(const ConvFunc& get_node_from_conv_op,
48+
const ElementwiseAddFunc& get_node_from_elementwise_add_op,
49+
const CanFuseFunc& can_fuse_func);
50+
51+
ConvFunc get_node_from_conv_op;
52+
ElementwiseAddFunc get_node_from_elementwise_add_op;
53+
CanFuseFunc can_fuse_func;
4154

42-
template <typename CONV_FUNC, typename ELEMENTWISE_ADD_FUNC,
43-
typename HANDLER_FUNC = handler_func>
44-
HANDLER_FUNC GenerateFuseHandler(
45-
const patterns::Conv& conv_pattern,
46-
const patterns::ElementwiseAdd& elementwise_add_pattern,
47-
CONV_FUNC get_node_from_conv_op,
48-
ELEMENTWISE_ADD_FUNC get_node_from_elementwise_add_op) const;
55+
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
56+
Graph* graph);
57+
};
4958

5059
public:
5160
virtual ~ResidualConnectionMKLDNNFusePass() {}
@@ -55,74 +64,6 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
5564

5665
const std::string name_scope_{"residual_connection_fuse_pass"};
5766
};
58-
59-
template <typename CONV_FUNC, typename ELEMENTWISE_ADD_FUNC,
60-
typename HANDLER_FUNC>
61-
HANDLER_FUNC ResidualConnectionMKLDNNFusePass::GenerateFuseHandler(
62-
const patterns::Conv& conv_pattern,
63-
const patterns::ElementwiseAdd& elementwise_add_pattern,
64-
CONV_FUNC get_node_from_conv_op,
65-
ELEMENTWISE_ADD_FUNC get_node_from_elementwise_add_op) const {
66-
return [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
67-
Node* conv_op;
68-
Node* conv_input;
69-
Node* conv_filter;
70-
Node* conv_output;
71-
72-
Node* elementwise_add_op;
73-
Node* elementwise_add_identity;
74-
Node* elementwise_add_out;
75-
76-
std::tie(conv_op, conv_input, conv_filter, conv_output) =
77-
get_node_from_conv_op(conv_pattern, subgraph);
78-
std::tie(elementwise_add_op, elementwise_add_identity,
79-
elementwise_add_out) =
80-
get_node_from_elementwise_add_op(elementwise_add_pattern, subgraph);
81-
82-
if (this->FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN)
83-
return;
84-
85-
if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
86-
87-
OpDesc op_desc;
88-
op_desc.SetType("conv2d");
89-
90-
op_desc.SetInput("Input", {conv_input->Name()});
91-
op_desc.SetInput("Filter", {conv_filter->Name()});
92-
op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()});
93-
op_desc.SetOutput("Output", {conv_output->Name()});
94-
95-
bool has_bias;
96-
Node* conv_bias;
97-
98-
std::tie(has_bias, conv_bias) = this->HasBias(*conv_op);
99-
100-
if (has_bias) {
101-
op_desc.SetInput("Bias", {conv_bias->Name()});
102-
}
103-
104-
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
105-
op_desc.SetAttr(attr.first, attr.second);
106-
}
107-
108-
op_desc.SetAttr("fuse_residual_connection", true);
109-
110-
auto fused_conv_op = graph->CreateOpNode(&op_desc);
111-
112-
IR_NODE_LINK_TO(conv_input, fused_conv_op);
113-
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
114-
IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op);
115-
IR_NODE_LINK_TO(fused_conv_op, conv_output);
116-
117-
if (has_bias) {
118-
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
119-
}
120-
121-
CorrectGraphEdges(graph, elementwise_add_out, conv_output);
122-
GraphSafeRemoveNodes(graph,
123-
{elementwise_add_out, conv_op, elementwise_add_op});
124-
};
125-
}
12667
} // namespace ir
12768
} // namespace framework
12869
} // namespace paddle

0 commit comments

Comments
 (0)