Skip to content

Commit cbbacb2

Browse files
committed
Merge remote-tracking branch 'ups/develop' into fea/fusion_seqconv_add
test=develop
2 parents 603ba5e + da722d6 commit cbbacb2

File tree

8 files changed

+231
-4
lines changed

8 files changed

+231
-4
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ pass_library(conv_bn_fuse_pass inference)
4040
pass_library(seqconv_eltadd_relu_fuse_pass inference)
4141
if(WITH_MKLDNN)
4242
pass_library(mkldnn_placement_pass base)
43+
pass_library(conv_bias_mkldnn_fuse_pass inference)
4344
pass_library(conv_relu_mkldnn_fuse_pass inference)
4445
endif()
4546

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
16+
#include <functional>
17+
#include <string>
18+
#include <vector>
19+
#include "paddle/fluid/framework/lod_tensor.h"
20+
#include "paddle/fluid/platform/enforce.h"
21+
22+
namespace paddle {
23+
namespace framework {
24+
namespace ir {
25+
26+
template <typename BinaryOperation>
27+
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
28+
BinaryOperation f) {
29+
PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims());
30+
LoDTensor vec_y;
31+
vec_y.Resize(vec_a.dims());
32+
const float* a = vec_a.data<float>();
33+
const float* b = vec_b.data<float>();
34+
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
35+
for (int i = 0; i < vec_a.numel(); i++) {
36+
y[i] = f(a[i], b[i]);
37+
}
38+
return vec_y;
39+
}
40+
41+
// Fuses each conv2d + elementwise_add(bias) pair found in the graph into a
// single conv2d carrying a "Bias" input.  Two cases:
//   * conv already has a bias: the elementwise bias is added into the existing
//     bias tensor in place and the add op is removed;
//   * conv has no bias: a fresh conv2d op node is created that consumes the
//     elementwise bias directly, replacing both original ops.
// Takes ownership of the graph and returns it after mutation.
std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init(name_scope_, graph.get());

  // param_scope() holds the persistable parameters (weights/biases) whose
  // tensors the fusion mutates below.
  auto* scope = param_scope();
  PADDLE_ENFORCE(scope);

  GraphPatternDetector gpd;
  // Anchor node of the pattern: any variable feeding a conv2d's "Input" slot.
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input("conv2d", "Input");
  patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bias_pattern(conv_input);
  int found_conv_bias_count = 0;
  // Invoked once per matched conv2d -> elementwise_add subgraph.
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvBias fuse";
    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
                              conv_bias_pattern);                      // Filter
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_bias_pattern);  // tmp
    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_bias_pattern);  // CONV op
    // bias
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_bias, eltwise_bias, conv_bias_pattern);
    // output
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bias_pattern);
    // elementwise_add op
    GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);

    PADDLE_ENFORCE(subgraph.count(conv_input));

    // check if fuse can be done and if MKL-DNN should be used
    FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
    if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {
      VLOG(3) << "do not perform conv+bias fuse";
      return;
    }

    auto* eltwise_bias_tensor =
        scope->FindVar(eltwise_bias->Name())->GetMutable<LoDTensor>();

    // A conv2d may or may not declare a "Bias" input slot, and a declared slot
    // may still be empty; both checks are needed before reading Input("Bias").
    auto input_names = conv->Op()->InputNames();
    bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
                    input_names.end();
    if (has_bias && conv->Op()->Input("Bias").size() > 0) {
      auto conv_bias_names = conv->Op()->Input("Bias");
      // add eltwise bias to existing conv bias
      PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
      auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
      auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
      PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), eltwise_bias_tensor->dims());
      // Fold the two biases in place: conv_bias += eltwise_bias.
      *conv_bias_tensor = tensor_apply_eltwise(
          *conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>());

      // Re-route the conv's output straight to the elementwise_add's output
      // variable; the intermediate conv_out and the add op become redundant.
      conv->Op()->SetOutput("Output",
                            std::vector<std::string>({eltwise_out->Name()}));

      GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});

      IR_NODE_LINK_TO(conv, eltwise_out);
    } else {
      // take eltwise bias as conv bias
      OpDesc desc;

      desc.SetInput(
          "Input", std::vector<std::string>({subgraph.at(conv_input)->Name()}));
      desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
      desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
      desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
      desc.SetType("conv2d");

      // Preserve every attribute (strides, paddings, use_mkldnn, ...) of the
      // original conv op on the replacement node.
      for (auto& attr : conv->Op()->GetAttrMap()) {
        desc.SetAttr(attr.first, attr.second);
      }
      auto conv_bias_node = g->CreateOpNode(&desc);

      IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
      IR_NODE_LINK_TO(conv_weight, conv_bias_node);
      IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
      IR_NODE_LINK_TO(conv_bias_node, eltwise_out);

      // The old conv op, the add op and the intermediate output are now
      // unreachable; drop them from the graph.
      GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
    }

    found_conv_bias_count++;
  };
  gpd(graph.get(), handler);
  AddStatis(found_conv_bias_count);
  return graph;
}
133+
} // namespace ir
134+
} // namespace framework
135+
} // namespace paddle
136+
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
137+
paddle::framework::ir::ConvBiasFusePass);
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#pragma once
15+
#include <string>
16+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
17+
#include "paddle/fluid/framework/ir/graph.h"
18+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
19+
#include "paddle/fluid/framework/ir/pass.h"
20+
namespace paddle {
21+
namespace framework {
22+
namespace ir {
23+
/*
 * Fuse a conv2d op and the elementwise_add that applies its bias into a
 * single conv2d op with a "Bias" input.
 */
26+
class ConvBiasFusePass : public FusePassBase {
27+
public:
28+
virtual ~ConvBiasFusePass() {}
29+
30+
protected:
31+
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
32+
const std::string name_scope_{"conv_bias_mkldnn_fuse"};
33+
};
34+
} // namespace ir
35+
} // namespace framework
36+
} // namespace paddle

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,39 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
10161016
return ele_add_grad;
10171017
}
10181018

1019+
// Builds the conv2d -> elementwise_add detection pattern rooted at
// `conv_input` and returns the node representing the chain's final output.
PDNode *patterns::ConvBias::operator()(
    paddle::framework::ir::PDNode *conv_input) {
  conv_input->assert_is_op_input("conv2d", "Input");

  // Operator nodes: the convolution and the bias-add that consumes it.
  auto *conv = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
  auto *eltwise =
      pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");

  // Persistable filter weights feeding the convolution.
  auto *filter = pattern->NewNode(conv_weight_repr());
  filter->AsInput()
      ->assert_is_persistable_var()
      ->assert_is_op_input("conv2d", "Filter");

  // Convolution result: consumed only by the elementwise_add, so it can be
  // dropped from the IR once the fuse is applied.
  auto *conv_result = pattern->NewNode(conv_out_repr());
  conv_result->AsIntermediate()
      ->assert_is_only_output_of_op("conv2d")
      ->assert_is_op_input("elementwise_add");

  // Persistable bias arriving through the elementwise_add's "Y" input.
  auto *bias = pattern->NewNode(eltwise_bias_repr());
  bias->AsInput()
      ->assert_is_persistable_var()
      ->assert_is_op_input("elementwise_add", "Y");

  // Final output of the chain.
  auto *fused_out = pattern->NewNode(eltwise_out_repr());
  fused_out->AsOutput()->assert_is_op_output("elementwise_add");

  // Wiring: conv(input, filter) -> conv_result; add(conv_result, bias) -> out.
  conv->LinksFrom({conv_input, filter}).LinksTo({conv_result});
  eltwise->LinksFrom({conv_result, bias}).LinksTo({fused_out});
  return fused_out;
}
1051+
10191052
} // namespace ir
10201053
} // namespace framework
10211054
} // namespace paddle

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,27 @@ struct ElewiseAddActInplaceGrad : public PatternBase {
603603
PATTERN_DECL_NODE(d_ele_y);
604604
PATTERN_DECL_NODE(ele_y);
605605
};
606+
607+
// Conv with Elementwise_add as bias
608+
// op: conv + elementwise_add
609+
// named nodes:
610+
// conv_input, conv_weight,
611+
// conv_out, conv,
612+
// eltwise_bias, eltwise_out,
613+
// elementwise_add
614+
struct ConvBias : public PatternBase {
  ConvBias(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "conv_bias") {}
  // Builds the pattern starting from the conv2d "Input" variable node;
  // returns the elementwise_add output node (defined in
  // graph_pattern_detector.cc).
  PDNode* operator()(PDNode* conv_input);
  // declare operator node's name
  PATTERN_DECL_NODE(conv);
  PATTERN_DECL_NODE(eltwise);
  // declare variable node's name
  PATTERN_DECL_NODE(conv_weight);
  PATTERN_DECL_NODE(conv_out);
  PATTERN_DECL_NODE(eltwise_bias);
  PATTERN_DECL_NODE(eltwise_out);
};
606627
} // namespace patterns
607628

608629
// Link two ir::Nodes from each other.

paddle/fluid/inference/analysis/analyzer.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,12 @@ Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
101101

102102
void Analyzer::Run(Argument* argument) {
103103
std::vector<std::string> passes;
104+
#ifdef PADDLE_WITH_MKLDNN
104105
if (use_mkldnn_) {
105106
VLOG(3) << "Adding MKL-DNN placement pass";
106107
passes.push_back("mkldnn_placement_pass");
107108
}
109+
#endif
108110
for (auto& pass : ir_passes_) {
109111
if (!disabled_ir_passes_.count(pass)) {
110112
passes.push_back(pass);

paddle/fluid/inference/analysis/analyzer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ class Analyzer : public OrderedRegistry<PassManager> {
8080
"conv_bn_fuse_pass", //
8181
"conv_eltwiseadd_bn_fuse_pass", //
8282
#ifdef PADDLE_WITH_MKLDNN
83+
"conv_bias_mkldnn_fuse_pass", //
8384
"conv_relu_mkldnn_fuse_pass", //
8485
#endif
8586
}};

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,6 @@ bool AnalysisPredictor::Init(
7777
inference_program_ = program;
7878
}
7979

80-
if (config_._use_mkldnn) {
81-
executor_->EnableMKLDNN(*inference_program_);
82-
}
83-
8480
executor_->Prepare(scope_.get(), *inference_program_, 0,
8581
config_.use_feed_fetch_ops);
8682

0 commit comments

Comments
 (0)