
Commit d7509d6

Conv+Bias: Support non-null bias
test=develop
1 parent 91e8fba commit d7509d6

File tree

5 files changed: +82 additions, -134 deletions

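Summary of the change: when the pattern conv2d -> elementwise_add is matched, the elementwise_add bias is folded into the convolution. If the conv2d already has a non-null Bias input, the two bias tensors are summed element-wise in place (via the new tensor_apply_eltwise helper); otherwise a new conv2d op is created that takes the elementwise bias as its Bias input. In both cases the elementwise_add op and the intermediate conv output are removed from the graph, and the pattern now only matches persistable bias variables.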

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -56,6 +56,5 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph
 cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
 cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
 if (WITH_MKLDNN)
-  cc_test(test_conv_bias_mkldnn_fuse_pass SRCS conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass)
   cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
 endif ()

paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc

Lines changed: 79 additions & 27 deletions
@@ -11,24 +11,48 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
 #include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
+#include <functional>
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/platform/enforce.h"
+
 namespace paddle {
 namespace framework {
 namespace ir {
+
+template <typename BinaryOperation>
+LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
+                               BinaryOperation f) {
+  PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims());
+  LoDTensor vec_y;
+  vec_y.Resize(vec_a.dims());
+  const float* a = vec_a.data<float>();
+  const float* b = vec_b.data<float>();
+  float* y = vec_y.mutable_data<float>(platform::CPUPlace());
+  for (int i = 0; i < vec_a.numel(); i++) {
+    y[i] = f(a[i], b[i]);
+  }
+  return vec_y;
+}
+
 std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("conv_bias_mkldnn_fuse", graph.get());
+  FusePassBase::Init(name_scope_, graph.get());
+
+  auto* scope = param_scope();
+  PADDLE_ENFORCE(scope);
+
   GraphPatternDetector gpd;
-  auto* conv_input = gpd.mutable_pattern()
-                         ->NewNode("conv_bias_mkldnn_fuse/conv_input")
-                         ->AsInput()
-                         ->assert_is_op_input("conv2d", "Input");
-  patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(),
-                                       "conv_bias_mkldnn_fuse");
+  auto* conv_input =
+      gpd.mutable_pattern()
+          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
+          ->AsInput()
+          ->assert_is_op_input("conv2d", "Input");
+  patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
   conv_bias_pattern(conv_input);
   int found_conv_bias_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
@@ -44,27 +68,55 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bias_pattern);
     // elementwise_add op
     GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);
-    // Create an ConvBias Node.
-    OpDesc desc;
-    std::string conv_bias_i_in = subgraph.at(conv_input)->Name();
-    std::string conv_bias_w_in = conv_weight->Name();
-    std::string conv_bias_b_in = eltwise_bias->Name();
-    std::string conv_bias_out = eltwise_out->Name();
-    desc.SetInput("Input", std::vector<std::string>({conv_bias_i_in}));
-    desc.SetInput("Filter", std::vector<std::string>({conv_bias_w_in}));
-    desc.SetInput("Bias", std::vector<std::string>({conv_bias_b_in}));
-    desc.SetOutput("Output", std::vector<std::string>({conv_bias_out}));
-    desc.SetType("conv2d");
-    for (auto& attr : conv->Op()->GetAttrMap()) {
-      desc.SetAttr(attr.first, attr.second);
-    }
-    auto conv_bias_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
-    GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
+
     PADDLE_ENFORCE(subgraph.count(conv_input));
-    IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
-    IR_NODE_LINK_TO(conv_weight, conv_bias_node);
-    IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
-    IR_NODE_LINK_TO(conv_bias_node, eltwise_out);
+
+    auto* eltwise_bias_tensor =
+        scope->FindVar(eltwise_bias->Name())->GetMutable<LoDTensor>();
+
+    auto input_names = conv->Op()->InputNames();
+    bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
+                    input_names.end();
+    if (has_bias && conv->Op()->Input("Bias").size() > 0) {
+      auto conv_bias_names = conv->Op()->Input("Bias");
+      // add eltwise bias to existing conv bias
+      PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
+      auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
+      auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
+      PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), eltwise_bias_tensor->dims());
+      *conv_bias_tensor = tensor_apply_eltwise(
+          *conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>());
+
+      conv->Op()->SetOutput("Output",
+                            std::vector<std::string>({eltwise_out->Name()}));
+
+      GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});
+
+      IR_NODE_LINK_TO(conv, eltwise_out);
+    } else {
+      // take eltwise bias as conv bias
+      OpDesc desc;
+
+      desc.SetInput(
+          "Input", std::vector<std::string>({subgraph.at(conv_input)->Name()}));
+      desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
+      desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
+      desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
+      desc.SetType("conv2d");
+
+      for (auto& attr : conv->Op()->GetAttrMap()) {
+        desc.SetAttr(attr.first, attr.second);
+      }
+      auto conv_bias_node = g->CreateOpNode(&desc);
+
+      IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
+      IR_NODE_LINK_TO(conv_weight, conv_bias_node);
+      IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
+      IR_NODE_LINK_TO(conv_bias_node, eltwise_out);
+
+      GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
+    }
+
     found_conv_bias_count++;
   };
   gpd(graph.get(), handler);
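
Note (not part of the patch): the tensor_apply_eltwise helper added above is what merges an existing conv2d bias with the elementwise_add bias. A minimal standalone C++ sketch of the same idea, using std::vector in place of LoDTensor and std::plus<float> as the binary operation (all names here are illustrative only):

#include <cassert>
#include <functional>
#include <vector>

// Illustrative stand-in for tensor_apply_eltwise: combine two equally sized
// bias vectors element by element with a binary functor (std::plus<float>
// in the pass), producing the merged conv2d bias.
template <typename BinaryOperation>
std::vector<float> apply_eltwise(const std::vector<float>& a,
                                 const std::vector<float>& b,
                                 BinaryOperation f) {
  assert(a.size() == b.size());
  std::vector<float> y(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    y[i] = f(a[i], b[i]);
  }
  return y;
}

int main() {
  std::vector<float> conv_bias = {0.5f, -1.0f, 2.0f};
  std::vector<float> eltwise_bias = {0.25f, 1.0f, -2.0f};
  // Merged bias that would replace the original conv2d Bias parameter.
  std::vector<float> merged =
      apply_eltwise(conv_bias, eltwise_bias, std::plus<float>());
  assert(merged[0] == 0.75f && merged[1] == 0.0f && merged[2] == 0.0f);
  return 0;
}

When the matched conv2d has no Bias input at all, the pass takes the else branch instead and simply wires the elementwise_add bias in as the Bias input of a newly created conv2d op.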

paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #pragma once
+#include <string>
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
@@ -28,6 +29,7 @@ class ConvBiasFusePass : public FusePassBase {
 
  protected:
   std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+  const std::string name_scope_{"conv_bias_mkldnn_fuse"};
 };
 }  // namespace ir
 }  // namespace framework

paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass_tester.cc

Lines changed: 0 additions & 106 deletions
This file was deleted.

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 1 addition & 0 deletions
@@ -987,6 +987,7 @@ PDNode *patterns::ConvBias::operator()(
   // Bias stored in elementwise_add
   auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())
                                ->AsInput()
+                               ->assert_is_persistable_var()
                                ->assert_is_op_input("elementwise_add", "Y");
   // output
   auto *eltwise_out_var = pattern->NewNode(eltwise_out_repr())
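
Note: with assert_is_persistable_var(), the pattern only matches when the elementwise_add Y input is a persistable variable, i.e. a parameter that lives in the scope. This is presumably what allows the pass to read (and, in the existing-bias case, overwrite) the bias tensor through scope->FindVar() at pass time, which would not be possible for a runtime activation.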
