Skip to content

Commit ce248a1

Browse files
authored
Merge pull request #13368 from Sand3r-/mgallus/conv-bias-pass
[MKLDNN] Pass: Fuse Conv + Bias
2 parents 7e651c8 + 40b17be commit ce248a1

File tree

7 files changed

+276
-2
lines changed

7 files changed

+276
-2
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ pass_library(graph_to_program_pass base)
3030
pass_library(graph_viz_pass base)
3131
pass_library(fc_fuse_pass inference)
3232
if (WITH_MKLDNN)
33+
pass_library(conv_bias_mkldnn_fuse_pass inference)
3334
pass_library(conv_relu_mkldnn_fuse_pass inference)
3435
endif ()
3536
pass_library(attention_lstm_fuse_pass inference)
@@ -52,6 +53,7 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r
5253
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
5354
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
5455
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
55-
if (WITH_MKLDNN)
56+
if(WITH_MKLDNN)
57+
cc_test(test_conv_bias_mkldnn_fuse_pass SRCS conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass)
5658
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
57-
endif ()
59+
endif()
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
15+
#include <string>
16+
#include <vector>
17+
#include "paddle/fluid/platform/enforce.h"
18+
namespace paddle {
19+
namespace framework {
20+
namespace ir {
21+
std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
22+
std::unique_ptr<ir::Graph> graph) const {
23+
PADDLE_ENFORCE(graph.get());
24+
FusePassBase::Init("conv_bias_mkldnn_fuse", graph.get());
25+
GraphPatternDetector gpd;
26+
auto* conv_input = gpd.mutable_pattern()
27+
->NewNode("conv_bias_mkldnn_fuse/conv_input")
28+
->AsInput()
29+
->assert_is_op_input("conv2d", "Input");
30+
patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(),
31+
"conv_bias_mkldnn_fuse");
32+
conv_bias_pattern(conv_input);
33+
int found_conv_bias_count = 0;
34+
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
35+
Graph* g) {
36+
VLOG(4) << "handle ConvBias fuse";
37+
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
38+
conv_bias_pattern); // Filter
39+
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_bias_pattern); // tmp
40+
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_bias_pattern); // CONV op
41+
// bias
42+
GET_IR_NODE_FROM_SUBGRAPH(eltwise_bias, eltwise_bias, conv_bias_pattern);
43+
// output
44+
GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bias_pattern);
45+
// elementwise_add op
46+
GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);
47+
// Create an ConvBias Node.
48+
OpDesc desc;
49+
std::string conv_bias_i_in = subgraph.at(conv_input)->Name();
50+
std::string conv_bias_w_in = conv_weight->Name();
51+
std::string conv_bias_b_in = eltwise_bias->Name();
52+
std::string conv_bias_out = eltwise_out->Name();
53+
desc.SetInput("Input", std::vector<std::string>({conv_bias_i_in}));
54+
desc.SetInput("Filter", std::vector<std::string>({conv_bias_w_in}));
55+
desc.SetInput("Bias", std::vector<std::string>({conv_bias_b_in}));
56+
desc.SetOutput("Output", std::vector<std::string>({conv_bias_out}));
57+
desc.SetType("conv2d");
58+
for (auto& attr : conv->Op()->GetAttrMap()) {
59+
desc.SetAttr(attr.first, attr.second);
60+
}
61+
auto conv_bias_node = g->CreateOpNode(&desc); // OpDesc will be copied.
62+
GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
63+
PADDLE_ENFORCE(subgraph.count(conv_input));
64+
IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
65+
IR_NODE_LINK_TO(conv_weight, conv_bias_node);
66+
IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
67+
IR_NODE_LINK_TO(conv_bias_node, eltwise_out);
68+
found_conv_bias_count++;
69+
};
70+
gpd(graph.get(), handler);
71+
AddStatis(found_conv_bias_count);
72+
return graph;
73+
}
74+
} // namespace ir
75+
} // namespace framework
76+
} // namespace paddle
77+
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
78+
paddle::framework::ir::ConvBiasFusePass);
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#pragma once
15+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
16+
#include "paddle/fluid/framework/ir/graph.h"
17+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
18+
#include "paddle/fluid/framework/ir/pass.h"
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
/*
23+
* Fuse the Conv and Elementwise_add to a ConvBiasOp.
24+
*/
25+
class ConvBiasFusePass : public FusePassBase {
26+
public:
27+
virtual ~ConvBiasFusePass() {}
28+
29+
protected:
30+
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
31+
};
32+
} // namespace ir
33+
} // namespace framework
34+
} // namespace paddle
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
16+
17+
#include <gtest/gtest.h>
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
void SetOp(ProgramDesc* prog, const std::string& type,
24+
const std::vector<std::string>& inputs,
25+
const std::vector<std::string>& outputs) {
26+
auto* op = prog->MutableBlock(0)->AppendOp();
27+
op->SetType(type);
28+
if (type == "conv2d") {
29+
op->SetAttr("use_mkldnn", true);
30+
op->SetInput("Input", {inputs[0]});
31+
op->SetInput("Filter", {inputs[1]});
32+
} else if (type == "elementwise_add") {
33+
op->SetInput("X", {inputs[0]});
34+
op->SetInput("Y", {inputs[1]});
35+
}
36+
op->SetOutput("Out", outputs);
37+
}
38+
39+
// a->OP0->b
40+
// b->OP1->c
41+
// (c, weights)->conv->f
42+
// (f, bias)->elementwise_add->g
43+
ProgramDesc BuildProgramDesc() {
44+
ProgramDesc prog;
45+
for (auto& v :
46+
std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g"})) {
47+
auto* var = prog.MutableBlock(0)->Var(v);
48+
var->SetType(proto::VarType::SELECTED_ROWS);
49+
if (v == "weights" || v == "bias") {
50+
var->SetPersistable(true);
51+
}
52+
}
53+
54+
SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
55+
std::vector<std::string>({"b"}));
56+
SetOp(&prog, "OP1", std::vector<std::string>({"b"}),
57+
std::vector<std::string>({"c"}));
58+
SetOp(&prog, "conv2d", std::vector<std::string>({"c", "weights"}),
59+
std::vector<std::string>({"f"}));
60+
SetOp(&prog, "elementwise_add", std::vector<std::string>({"f", "bias"}),
61+
std::vector<std::string>({"g"}));
62+
63+
return prog;
64+
}
65+
66+
TEST(ConvBiasFusePass, basic) {
67+
auto prog = BuildProgramDesc();
68+
69+
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
70+
71+
auto pass = PassRegistry::Instance().Get("conv_bias_mkldnn_fuse_pass");
72+
73+
int original_nodes_num = graph->Nodes().size();
74+
75+
graph = pass->Apply(std::move(graph));
76+
77+
int current_nodes_num = graph->Nodes().size();
78+
79+
// Remove 3 Nodes: conv, elementwise_add, conv_out
80+
// Add 1 Node: ConvBias
81+
EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
82+
83+
// Assert conv_bias op in newly generated graph
84+
int conv_bias_count = 0;
85+
86+
for (auto* node : graph->Nodes()) {
87+
if (node->IsOp() && node->Op()->Type() == "conv2d") {
88+
if (node->Op()->HasAttr("use_mkldnn")) {
89+
bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn"));
90+
if (use_mkldnn) {
91+
auto names = node->Op()->InputNames();
92+
if (std::find(names.begin(), names.end(), "Bias") != names.end()) {
93+
conv_bias_count++;
94+
}
95+
}
96+
}
97+
}
98+
}
99+
EXPECT_EQ(conv_bias_count, 1);
100+
}
101+
102+
} // namespace ir
103+
} // namespace framework
104+
} // namespace paddle
105+
106+
USE_PASS(conv_bias_mkldnn_fuse_pass);

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -964,6 +964,38 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
964964
return ele_add_grad;
965965
}
966966

967+
PDNode *patterns::ConvBias::operator()(
968+
paddle::framework::ir::PDNode *conv_input) {
969+
// Create Operators
970+
conv_input->assert_is_op_input("conv2d", "Input");
971+
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
972+
auto *eltiwse_op =
973+
pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
974+
// Create variables
975+
// Filter
976+
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
977+
->AsInput()
978+
->assert_is_persistable_var()
979+
->assert_is_op_input("conv2d", "Filter");
980+
// intermediate variable, will be removed in the IR after fuse.
981+
auto *conv_out_var = pattern->NewNode(conv_out_repr())
982+
->AsIntermediate()
983+
->assert_is_only_output_of_op("conv2d")
984+
->assert_is_op_input("elementwise_add");
985+
// Bias stored in elementwise_add
986+
auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())
987+
->AsInput()
988+
->assert_is_op_input("elementwise_add", "Y");
989+
// output
990+
auto *eltwise_out_var = pattern->NewNode(eltwise_out_repr())
991+
->AsOutput()
992+
->assert_is_op_output("elementwise_add");
993+
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
994+
eltiwse_op->LinksFrom({conv_out_var, eltwise_bias_var})
995+
.LinksTo({eltwise_out_var});
996+
return eltwise_out_var;
997+
}
998+
967999
} // namespace ir
9681000
} // namespace framework
9691001
} // namespace paddle

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,27 @@ struct ElewiseAddActInplaceGrad : public PatternBase {
578578
PATTERN_DECL_NODE(d_ele_y);
579579
PATTERN_DECL_NODE(ele_y);
580580
};
581+
582+
// Conv with Elementwise_add as bias
583+
// op: conv + elementwise_add
584+
// named nodes:
585+
// conv_input, conv_weight,
586+
// conv_out, conv,
587+
// eltwise_bias, eltwise_out,
588+
// elementwise_add
589+
struct ConvBias : public PatternBase {
590+
ConvBias(PDPattern* pattern, const std::string& name_scope)
591+
: PatternBase(pattern, name_scope, "conv_bias") {}
592+
PDNode* operator()(PDNode* conv_input);
593+
// declare operator node's name
594+
PATTERN_DECL_NODE(conv);
595+
PATTERN_DECL_NODE(eltwise);
596+
// declare variable node's name
597+
PATTERN_DECL_NODE(conv_weight);
598+
PATTERN_DECL_NODE(conv_out);
599+
PATTERN_DECL_NODE(eltwise_bias);
600+
PATTERN_DECL_NODE(eltwise_out);
601+
};
581602
} // namespace patterns
582603

583604
// Link two ir::Nodes from each other.

paddle/fluid/inference/analysis/analyzer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class Analyzer : public OrderedRegistry<PassManager> {
7676
"conv_bn_fuse_pass", //
7777
"conv_eltwiseadd_bn_fuse_pass", //
7878
#ifdef PADDLE_WITH_MKLDNN
79+
"conv_bias_mkldnn_fuse_pass", //
7980
"conv_relu_mkldnn_fuse_pass", //
8081
#endif
8182
}};

0 commit comments

Comments
 (0)