Commit da722d6

Merge pull request #13858 from Sand3r-/mgallus/conv-bias-pass
[MKLDNN] Fuse Conv + Bias using Pass
2 parents a4b48f7 + f9ca318 commit da722d6
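This change teaches the inference graph passes to collapse a conv2d op followed by an elementwise_add of a persistable bias into a single conv2d that carries the bias itself (either by giving the conv a Bias input or by folding the extra bias into an existing one). Below is a minimal, self-contained C++ sketch of the arithmetic identity this fusion relies on; the toy 1-D "conv", the helper names, and the sample data are illustrative only and are not part of the commit.

// Illustrative sketch only (not PaddlePaddle code): the fuse is valid because
// adding a bias after a convolution equals a convolution that applies the bias
// itself, and two successive bias adds can be folded by summing the biases.
#include <cassert>
#include <cmath>
#include <vector>

// Toy per-output-channel "convolution": dot(x, w[oc]).
std::vector<float> conv(const std::vector<float>& x,
                        const std::vector<std::vector<float>>& w) {
  std::vector<float> y(w.size(), 0.f);
  for (size_t oc = 0; oc < w.size(); ++oc)
    for (size_t i = 0; i < x.size(); ++i) y[oc] += w[oc][i] * x[i];
  return y;
}

std::vector<float> add(std::vector<float> a, const std::vector<float>& b) {
  for (size_t i = 0; i < a.size(); ++i) a[i] += b[i];
  return a;
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f};
  std::vector<std::vector<float>> w = {{0.5f, -1.f, 2.f}, {1.f, 1.f, 1.f}};
  std::vector<float> conv_bias = {0.1f, 0.2f};     // bias already on the conv
  std::vector<float> eltwise_bias = {0.3f, 0.4f};  // bias from elementwise_add

  // Unfused graph: conv2d -> (+ conv_bias) -> elementwise_add(eltwise_bias).
  auto unfused = add(add(conv(x, w), conv_bias), eltwise_bias);
  // Fused graph: fold the two biases once, run a single biased conv2d.
  auto fused = add(conv(x, w), add(conv_bias, eltwise_bias));

  for (size_t i = 0; i < unfused.size(); ++i)
    assert(std::fabs(unfused[i] - fused[i]) < 1e-6f);
  return 0;
}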

File tree

6 files changed: +229 −0 lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ pass_library(seq_concat_fc_fuse_pass inference)
 pass_library(conv_bn_fuse_pass inference)
 if(WITH_MKLDNN)
   pass_library(mkldnn_placement_pass base)
+  pass_library(conv_bias_mkldnn_fuse_pass inference)
   pass_library(conv_relu_mkldnn_fuse_pass inference)
 endif()

paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
#include <functional>
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {

template <typename BinaryOperation>
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
                               BinaryOperation f) {
  PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims());
  LoDTensor vec_y;
  vec_y.Resize(vec_a.dims());
  const float* a = vec_a.data<float>();
  const float* b = vec_b.data<float>();
  float* y = vec_y.mutable_data<float>(platform::CPUPlace());
  for (int i = 0; i < vec_a.numel(); i++) {
    y[i] = f(a[i], b[i]);
  }
  return vec_y;
}

std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init(name_scope_, graph.get());

  auto* scope = param_scope();
  PADDLE_ENFORCE(scope);

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input("conv2d", "Input");
  patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bias_pattern(conv_input);
  int found_conv_bias_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvBias fuse";
    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
                              conv_bias_pattern);                      // Filter
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_bias_pattern);  // tmp
    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_bias_pattern);  // CONV op
    // bias
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_bias, eltwise_bias, conv_bias_pattern);
    // output
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bias_pattern);
    // elementwise_add op
    GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);

    PADDLE_ENFORCE(subgraph.count(conv_input));

    // check if fuse can be done and if MKL-DNN should be used
    FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
    if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {
      VLOG(3) << "do not perform conv+bias fuse";
      return;
    }

    auto* eltwise_bias_tensor =
        scope->FindVar(eltwise_bias->Name())->GetMutable<LoDTensor>();

    auto input_names = conv->Op()->InputNames();
    bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
                    input_names.end();
    if (has_bias && conv->Op()->Input("Bias").size() > 0) {
      auto conv_bias_names = conv->Op()->Input("Bias");
      // add eltwise bias to existing conv bias
      PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
      auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
      auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
      PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), eltwise_bias_tensor->dims());
      *conv_bias_tensor = tensor_apply_eltwise(
          *conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>());

      conv->Op()->SetOutput("Output",
                            std::vector<std::string>({eltwise_out->Name()}));

      GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});

      IR_NODE_LINK_TO(conv, eltwise_out);
    } else {
      // take eltwise bias as conv bias
      OpDesc desc;

      desc.SetInput(
          "Input", std::vector<std::string>({subgraph.at(conv_input)->Name()}));
      desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
      desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
      desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
      desc.SetType("conv2d");

      for (auto& attr : conv->Op()->GetAttrMap()) {
        desc.SetAttr(attr.first, attr.second);
      }
      auto conv_bias_node = g->CreateOpNode(&desc);

      IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
      IR_NODE_LINK_TO(conv_weight, conv_bias_node);
      IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
      IR_NODE_LINK_TO(conv_bias_node, eltwise_out);

      GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
    }

    found_conv_bias_count++;
  };
  gpd(graph.get(), handler);
  AddStatis(found_conv_bias_count);
  return graph;
}
}  // namespace ir
}  // namespace framework
}  // namespace paddle
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
              paddle::framework::ir::ConvBiasFusePass);
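For orientation, a pass registered this way is normally looked up by the name given to REGISTER_PASS and applied to an ir::Graph built from a ProgramDesc. The sketch below is a rough, assumed usage based on the Pass/PassRegistry interfaces of the framework at that time, not part of the commit; RunConvBiasFuse is a hypothetical helper, and attaching the parameter Scope that param_scope() reads from the graph is deliberately left out.

// Hedged sketch, not from the commit: how the registered pass might be run.
#include <memory>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"

namespace fw = paddle::framework;

std::unique_ptr<fw::ir::Graph> RunConvBiasFuse(const fw::ProgramDesc& prog) {
  auto graph = std::unique_ptr<fw::ir::Graph>(new fw::ir::Graph(prog));
  // NOTE: the pass reads weights/biases through param_scope(), so the caller
  // is expected to attach the parameter Scope to the graph first (not shown).
  auto pass = fw::ir::PassRegistry::Instance().Get("conv_bias_mkldnn_fuse_pass");
  return pass->Apply(std::move(graph));
}
// Linking against the pass typically also requires
// USE_PASS(conv_bias_mkldnn_fuse_pass); at file scope.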

paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
 * Fuse the Conv and Elementwise_add to a ConvBiasOp.
 */
class ConvBiasFusePass : public FusePassBase {
 public:
  virtual ~ConvBiasFusePass() {}

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
  const std::string name_scope_{"conv_bias_mkldnn_fuse"};
};
}  // namespace ir
}  // namespace framework
}  // namespace paddle

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 33 additions & 0 deletions
@@ -966,6 +966,39 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
   return ele_add_grad;
 }
 
+PDNode *patterns::ConvBias::operator()(
+    paddle::framework::ir::PDNode *conv_input) {
+  // Create Operators
+  conv_input->assert_is_op_input("conv2d", "Input");
+  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
+  auto *eltiwse_op =
+      pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
+  // Create variables
+  // Filter
+  auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
+                              ->AsInput()
+                              ->assert_is_persistable_var()
+                              ->assert_is_op_input("conv2d", "Filter");
+  // intermediate variable, will be removed in the IR after fuse.
+  auto *conv_out_var = pattern->NewNode(conv_out_repr())
+                           ->AsIntermediate()
+                           ->assert_is_only_output_of_op("conv2d")
+                           ->assert_is_op_input("elementwise_add");
+  // Bias stored in elementwise_add
+  auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())
+                               ->AsInput()
+                               ->assert_is_persistable_var()
+                               ->assert_is_op_input("elementwise_add", "Y");
+  // output
+  auto *eltwise_out_var = pattern->NewNode(eltwise_out_repr())
+                              ->AsOutput()
+                              ->assert_is_op_output("elementwise_add");
+  conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
+  eltiwse_op->LinksFrom({conv_out_var, eltwise_bias_var})
+      .LinksTo({eltwise_out_var});
+  return eltwise_out_var;
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 21 additions & 0 deletions
@@ -578,6 +578,27 @@ struct ElewiseAddActInplaceGrad : public PatternBase {
   PATTERN_DECL_NODE(d_ele_y);
   PATTERN_DECL_NODE(ele_y);
 };
+
+// Conv with Elementwise_add as bias
+// op: conv + elementwise_add
+// named nodes:
+// conv_input, conv_weight,
+// conv_out, conv,
+// eltwise_bias, eltwise_out,
+// elementwise_add
+struct ConvBias : public PatternBase {
+  ConvBias(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "conv_bias") {}
+  PDNode* operator()(PDNode* conv_input);
+  // declare operator node's name
+  PATTERN_DECL_NODE(conv);
+  PATTERN_DECL_NODE(eltwise);
+  // declare variable node's name
+  PATTERN_DECL_NODE(conv_weight);
+  PATTERN_DECL_NODE(conv_out);
+  PATTERN_DECL_NODE(eltwise_bias);
+  PATTERN_DECL_NODE(eltwise_out);
+};
 }  // namespace patterns
 
 // Link two ir::Nodes from each other.

paddle/fluid/inference/analysis/analyzer.h

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ class Analyzer : public OrderedRegistry<PassManager> {
       "conv_bn_fuse_pass",             //
       "conv_eltwiseadd_bn_fuse_pass",  //
 #ifdef PADDLE_WITH_MKLDNN
+      "conv_bias_mkldnn_fuse_pass",  //
       "conv_relu_mkldnn_fuse_pass",  //
 #endif
   }};
