Skip to content

Commit 4224089

Browse files
author
Tomasz Patejko
committed
MKLDNN residual connections fuse pass: Maybe removed and boost::optional used where it makes sense
1 parent 86fd3b3 commit 4224089

File tree

2 files changed

+81
-88
lines changed

2 files changed

+81
-88
lines changed

paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc

Lines changed: 70 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
9999
return false;
100100
}
101101

102-
std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name) {
102+
boost::optional<Node*> HasBias(const Node& op, const std::string& bias_name) {
103103
auto bias_input_names = op.Op()->Inputs();
104104
auto bias_it = bias_input_names.find(bias_name);
105105

@@ -113,19 +113,20 @@ std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name) {
113113
[&bias_names](Node* n) -> bool {
114114
return n->Name() == bias_names[0];
115115
});
116-
return std::make_pair(has_bias, *bias_names_it);
116+
return *bias_names_it;
117117
}
118118
}
119119

120-
return std::make_pair(false, nullptr);
120+
return boost::none;
121121
}
122122

123123
// Constructs a FuseHandler from three callbacks: one that extracts the conv
// nodes from a matched subgraph, one that extracts the elementwise_add nodes,
// and a predicate deciding whether two ops may be fused. Each callback is
// stored in the member of the same name.
ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler(
124124
const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv_op,
125125
const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
126126
get_node_from_elementwise_add_op,
127127
const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func)
128-
: get_node_from_conv_op{get_node_from_conv_op},
128+
// New in this commit: fusion_stats is initialised to a shared int counter
// starting at 0; operator() increments it after each successful fusion.
: fusion_stats{std::make_shared<int>(0)},
129+
get_node_from_conv_op{get_node_from_conv_op},
129130
get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
130131
can_fuse_func{can_fuse_func} {}
131132

@@ -157,13 +158,10 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
157158
op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()});
158159
op_desc.SetOutput("Output", {conv_output->Name()});
159160

160-
bool has_bias;
161-
Node* conv_bias;
161+
auto conv_bias = HasBias(*conv_op, "Bias");
162162

163-
std::tie(has_bias, conv_bias) = HasBias(*conv_op, "Bias");
164-
165-
if (has_bias) {
166-
op_desc.SetInput("Bias", {conv_bias->Name()});
163+
if (conv_bias) {
164+
op_desc.SetInput("Bias", {(*conv_bias)->Name()});
167165
}
168166

169167
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
@@ -179,40 +177,48 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
179177
IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op);
180178
IR_NODE_LINK_TO(fused_conv_op, conv_output);
181179

182-
if (has_bias) {
183-
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
180+
if (conv_bias) {
181+
IR_NODE_LINK_TO((*conv_bias), fused_conv_op);
184182
}
185183

186184
CorrectGraphEdges(graph, elementwise_add_out, conv_output);
187185
GraphSafeRemoveNodes(graph,
188186
{elementwise_add_out, conv_op, elementwise_add_op});
187+
(*fusion_stats)++;
188+
}
189+
190+
// Collects the four nodes of a matched convolution pattern from `subgraph`.
//
// Returns a tuple ordered as (conv_op, conv_input, conv_filter, conv_output).
// This member function replaces the identical lambdas that FuseConvAsX and
// FuseConvAsY each defined locally before this commit.
std::tuple<Node*, Node*, Node*, Node*>
191+
ResidualConnectionMKLDNNFusePass::GetNodesFromConv(
192+
const patterns::Conv& conv_pattern,
193+
const GraphPatternDetector::subgraph_t& subgraph) const {
194+
// NOTE(review): GET_IR_NODE_FROM_SUBGRAPH is defined elsewhere; it presumably
// declares a local Node* named by its first argument and reads `subgraph`
// implicitly -- confirm against the macro definition.
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
195+
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
196+
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
197+
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
198+
199+
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
189200
}
190201

191-
graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
192-
const std::string& name_scope_, graph_ptr graph) const {
202+
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
203+
const std::string& name_scope,
204+
const GraphWithStats& graph_with_stats) const {
205+
ir::Graph* graph;
206+
int stats;
207+
208+
std::tie(graph, stats) = graph_with_stats;
209+
193210
GraphPatternDetector gpd;
194211
auto pattern = gpd.mutable_pattern();
195212

196-
patterns::Conv conv_pattern{pattern, name_scope_};
213+
patterns::Conv conv_pattern{pattern, name_scope};
197214
auto conv_output = conv_pattern();
198215

199-
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
216+
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope};
200217
elementwise_add_pattern(
201218
conv_output,
202219
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
203220
conv_output->AsIntermediate();
204221

205-
auto get_node_from_conv =
206-
[&conv_pattern](const GraphPatternDetector::subgraph_t& subgraph)
207-
-> std::tuple<Node*, Node*, Node*, Node*> {
208-
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
209-
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
210-
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
211-
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
212-
213-
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
214-
};
215-
216222
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
217223
const GraphPatternDetector::subgraph_t& subgraph)
218224
-> std::tuple<Node*, Node*, Node*> {
@@ -227,43 +233,29 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
227233
elementwise_add_out);
228234
};
229235

230-
auto can_fuse = [this](Node* op1, Node* op2) -> bool {
231-
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
232-
};
233-
234-
auto fuse_handler =
235-
FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
236-
237-
gpd(graph.get(), fuse_handler);
238-
239-
return graph;
236+
return ExecuteHandlerOnGraph(
237+
&gpd, graph_with_stats,
238+
[this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
239+
return GetNodesFromConv(conv_pattern, subgraph);
240+
},
241+
get_node_from_elementwise_add);
240242
}
241243

242-
graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
243-
const std::string& name_scope_, graph_ptr graph) const {
244+
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
245+
const std::string& name_scope,
246+
const GraphWithStats& graph_with_stats) const {
244247
GraphPatternDetector gpd;
245248
auto pattern = gpd.mutable_pattern();
246249

247-
patterns::Conv conv_pattern{pattern, name_scope_};
250+
patterns::Conv conv_pattern{pattern, name_scope};
248251
auto conv_output = conv_pattern();
249252

250-
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
253+
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope};
251254
elementwise_add_pattern(
252255
pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()),
253256
conv_output);
254257
conv_output->AsIntermediate();
255258

256-
auto get_node_from_conv =
257-
[&conv_pattern](const GraphPatternDetector::subgraph_t& subgraph)
258-
-> std::tuple<Node*, Node*, Node*, Node*> {
259-
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
260-
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
261-
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
262-
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
263-
264-
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
265-
};
266-
267259
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
268260
const GraphPatternDetector::subgraph_t& subgraph)
269261
-> std::tuple<Node*, Node*, Node*> {
@@ -278,22 +270,45 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
278270
elementwise_add_out);
279271
};
280272

273+
return ExecuteHandlerOnGraph(
274+
&gpd, graph_with_stats,
275+
[this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
276+
return GetNodesFromConv(conv_pattern, subgraph);
277+
},
278+
get_node_from_elementwise_add);
279+
}
280+
281+
// Runs the pattern detector `gpd` over the graph held in `graph_with_stats`,
// using a FuseHandler built from the two node-extraction callbacks.
//
// Returns the same graph pointer paired with the incoming count plus the
// value reported by the handler's get_stats() after the run.
GraphWithStats ResidualConnectionMKLDNNFusePass::ExecuteHandlerOnGraph(
282+
GraphPatternDetector* gpd, const GraphWithStats& graph_with_stats,
283+
const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv,
284+
const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
285+
get_node_from_elementwise_add) const {
286+
ir::Graph* graph;
287+
int stats;
288+
289+
// Unpack the (graph, running fusion count) pair.
std::tie(graph, stats) = graph_with_stats;
290+
281291
// Two ops are fusable only when FindFuseOption reports FUSE_MKLDNN for them.
auto can_fuse = [this](Node* op1, Node* op2) -> bool {
282292
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
283293
};
284294

285295
auto fuse_handler =
286296
FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
287297

288-
gpd(graph.get(), fuse_handler);
298+
// The detector is now passed by pointer and the graph is a raw pointer.
(*gpd)(graph, fuse_handler);
289299

290-
return graph;
300+
// NOTE(review): fuse_handler is passed to the detector by value here, so the
// count read below presumably relies on fusion_stats being a shared counter
// that copies of the handler also update -- confirm in the header.
return std::make_pair(graph, stats + fuse_handler.get_stats());
291301
}
292302

293303
// Entry point of the pass: initialises the base class for `graph`, runs
// FuseConvAsX and then FuseConvAsY on it, records the total fusion count via
// AddStatis, and returns the graph to the caller.
graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
294304
FusePassBase::Init(name_scope_, graph.get());
295305

296-
return FuseConvAsY(name_scope_, FuseConvAsX(name_scope_, std::move(graph)));
306+
// The helpers now take a raw graph pointer plus a running count (starting at
// 0) instead of taking ownership of the unique_ptr.
auto fused_graph_with_stats = FuseConvAsY(
307+
name_scope_, FuseConvAsX(name_scope_, std::make_pair(graph.get(), 0)));
308+
309+
// NOTE(review): this writes to stdout on every pass run and looks like a
// leftover debug print -- consider removing it or routing it through the
// project's logging facility.
std::cout << "Fused graph " << fused_graph_with_stats.second << std::endl;
310+
AddStatis(fused_graph_with_stats.second);
311+
return graph;
297312
}
298313
} // namespace ir
299314
} // namespace framework

paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h

Lines changed: 11 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -27,43 +27,12 @@ namespace paddle {
2727
namespace framework {
2828
namespace ir {
2929

30-
// poor replacement for C++17 std::optional and Boost.Optional
31-
// Empty tag type selecting Maybe's in-place constructing constructor.
// (Removed by this commit.)
struct InPlace {};
32-
// NOTE(review): this is a non-inline variable definition at namespace scope in
// a header; if the header is included from more than one translation unit it
// would be defined multiple times -- confirm before reviving this code.
InPlace in_place;
33-
34-
// Minimal hand-rolled optional: raw aligned storage for one T plus a flag
// recording whether a T has been constructed in it. This commit removes the
// class; GraphWithStats now uses a plain int and HasBias uses boost::optional.
template <typename T>
35-
class Maybe {
36-
private:
37-
// Uninitialised storage sized and aligned for T; a T lives here only after
// the in-place constructor has run.
typename std::aligned_storage<sizeof(T), alignof(T)>::type data;
38-
bool is_initialized{false};
39-
40-
public:
41-
// Constructs the contained T in place from `args` and marks it initialised.
template <typename... Args>
42-
explicit Maybe(InPlace, Args&&... args) {
43-
new (&data) T(std::forward<Args>(args)...);
44-
is_initialized = true;
45-
}
46-
47-
// Creates an empty Maybe; no T is constructed.
Maybe() {}
48-
49-
// True when a T has been constructed. Not const-qualified and not explicit.
operator bool() { return is_initialized; }
50-
51-
// Returns the stored T without checking is_initialized.
T& value() { return *reinterpret_cast<T*>(&data); }
52-
53-
// NOTE(review): the destructor runs ~T() unconditionally, even for a
// default-constructed Maybe in which no T was ever created.
~Maybe() { reinterpret_cast<T*>(&data)->~T(); }
54-
};
55-
56-
// Convenience factory: builds a Maybe<T> whose contained T is constructed in
// place from `args`, forwarded unchanged. (Removed by this commit.)
template <typename T, typename... Args>
57-
Maybe<T> MakeMaybe(Args&&... args) {
58-
return Maybe<T>(in_place, std::forward<Args>(args)...);
59-
}
60-
6130
using graph_ptr = std::unique_ptr<ir::Graph>;
62-
using GraphWithStats = std::pair<ir::Graph*, Maybe<int>>;
31+
using GraphWithStats = std::pair<ir::Graph*, int>;
6332

6433
void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
6534
bool IsReachable(ir::Graph* graph, Node* from, Node* to);
66-
std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name);
35+
boost::optional<Node*> HasBias(const Node& op, const std::string& bias_name);
6736

6837
class ResidualConnectionMKLDNNFusePass : public FusePassBase {
6938
private:
@@ -79,6 +48,15 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
7948
using ElementwiseAddFunc = GetNodeFunc<std::tuple<Node*, Node*, Node*>>;
8049
using CanFuseFunc = std::function<bool(Node*, Node*)>;
8150

51+
std::tuple<Node*, Node*, Node*, Node*> GetNodesFromConv(
52+
const patterns::Conv& conv_pattern,
53+
const GraphPatternDetector::subgraph_t& subgraph) const;
54+
55+
GraphWithStats ExecuteHandlerOnGraph(
56+
GraphPatternDetector* gpd, const GraphWithStats& graph_with_stats,
57+
const ConvFunc& get_node_from_conv,
58+
const ElementwiseAddFunc& get_node_from_elementwise_add) const;
59+
8260
struct FuseHandler {
8361
FuseHandler(const ConvFunc& get_node_from_conv_op,
8462
const ElementwiseAddFunc& get_node_from_elementwise_add_op,

0 commit comments

Comments
 (0)