
Commit c3b70ae

Wojciech Uss authored and Superjomn committed
Add MKL-DNN placement pass (#13958)
* add MKL-DNN placement pass

  This patch also refactors conv+bn (includes changes from PR #13926),
  updated to use the mkldnn-placement-pass.

  test=develop

* remove redundant pass list
* add comment on the default first pass
* fix test for conv+relu mkldnn fuse
1 parent 909e134 commit c3b70ae

12 files changed: +301 −66 lines
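Once linked in, the new pass is looked up by its registered name like any other IR pass. The driver below is a minimal sketch (not part of this commit), assuming the ir::PassRegistry / Pass::Apply interface declared in pass.h; ApplyMKLDNNPlacement is a hypothetical helper name:

// Hypothetical driver (not in this commit): fetches the pass registered
// below via REGISTER_PASS(mkldnn_placement_pass, ...) and runs it.
#include <memory>
#include <utility>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace fw = paddle::framework;

std::unique_ptr<fw::ir::Graph> ApplyMKLDNNPlacement(
    std::unique_ptr<fw::ir::Graph> graph) {
  auto pass = fw::ir::PassRegistry::Instance().Get("mkldnn_placement_pass");
  // The pass flips use_mkldnn to true on every op that declares the
  // attribute, so the later MKL-DNN fuse passes see consistent flags.
  return pass->Apply(std::move(graph));
}

A binary that uses the pass also needs USE_PASS(mkldnn_placement_pass) so the registration object gets linked in; the CMake helper changed below collects those declarations into a file automatically.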

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 6 additions & 4 deletions
@@ -10,7 +10,7 @@ function(pass_library TARGET DEST)
   set(oneValueArgs "")
   set(multiValueArgs SRCS DEPS)
   cmake_parse_arguments(op_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
-  cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass ${op_library_DEPS})
+  cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${op_library_DEPS})
   # add more DEST here, such as train, dist and collect USE_PASS into a file automatically.
   if (${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference")
     message(STATUS "add pass ${TARGET} ${DEST}")
@@ -25,20 +25,22 @@ cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
 cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
 cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
 cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
+cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
 
 pass_library(graph_to_program_pass base)
 pass_library(graph_viz_pass base)
 pass_library(fc_fuse_pass inference)
-if (WITH_MKLDNN)
-  pass_library(conv_relu_mkldnn_fuse_pass inference)
-endif ()
 pass_library(attention_lstm_fuse_pass inference)
 pass_library(infer_clean_graph_pass inference)
 pass_library(fc_lstm_fuse_pass inference)
 pass_library(embedding_fc_lstm_fuse_pass inference)
 pass_library(fc_gru_fuse_pass inference)
 pass_library(seq_concat_fc_fuse_pass inference)
 pass_library(conv_bn_fuse_pass inference)
+if(WITH_MKLDNN)
+  pass_library(mkldnn_placement_pass base)
+  pass_library(conv_relu_mkldnn_fuse_pass inference)
+endif()
 
 cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )

paddle/fluid/framework/ir/conv_bn_fuse_pass.cc

Lines changed: 63 additions & 23 deletions
@@ -126,12 +126,21 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
     // conv, batch_norm,
     // conv_weight, conv_out,
     // bn_scale, bn_bias, bn_mean, bn_variance,
-    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance
+    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,
+    // bn_saved_variance
     GET_CONV_BN_NODES(conv_bn_pattern);
 
+    // check if fuse can be done and if MKL-DNN should be used
+    FuseOptions fuse_option = FindFuseOption(*conv, *batch_norm);
+    if (fuse_option == DO_NOT_FUSE) {
+      VLOG(3) << "do not perform conv+bn fuse";
+      return;
+    }
+
     // Create eltwise_y (conv bias) variable
     VarDesc eltwise_y_in_desc(
         patterns::PDNodeName(name_scope_, "eltwise_y_in"));
+    eltwise_y_in_desc.SetPersistable(true);
     auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
     auto* eltwise_y_in_tensor =
         scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();
@@ -151,27 +160,59 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
                                  *bn_mean, *bn_variance, eltwise_y_in_tensor,
                                  epsilon);
 
-    // Create an elementwise add node
-    OpDesc desc;
-    desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
-    desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
-    desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
-    desc.SetType("elementwise_add");
-    desc.SetAttr("axis", 1);
-    bool a = boost::get<bool>(conv->Op()->GetAttr("use_mkldnn"));
-    desc.SetAttr("use_mkldnn", a);
-    auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.
-
-    GraphSafeRemoveNodes(graph.get(), {bn_scale, bn_bias, bn_mean, bn_variance,
-                                       batch_norm, bn_mean_out, bn_variance_out,
-                                       bn_saved_mean, bn_saved_variance});
-
-    PADDLE_ENFORCE(subgraph.count(conv_input));
-    IR_NODE_LINK_TO(conv_out, eltwise_op);
-    IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
-    IR_NODE_LINK_TO(eltwise_op, bn_out);
-
-    found_conv_bn_count++;
+    // with MKL-DNN, fuse conv+bn into conv with bias
+    // without MKL-DNN, fuse conv+bn into conv+elementwise_add
+    if (fuse_option == FUSE_MKLDNN) {
+      auto input_names = conv->Op()->InputNames();
+      bool has_bias = std::find(input_names.begin(), input_names.end(),
+                                "Bias") != input_names.end();
+      if (has_bias && conv->Op()->Input("Bias").size() > 0) {
+        // reuse existing conv bias node
+        auto conv_bias_names = conv->Op()->Input("Bias");
+        PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
+        auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
+        auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
+        PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(),
+                          eltwise_y_in_tensor->dims());
+
+        auto eigen_conv_bias = EigenVector<float>::From(*conv_bias_tensor);
+        eigen_conv_bias += EigenVector<float>::From(*eltwise_y_in_tensor);
+      } else {
+        // add new conv_bias node
+        conv->Op()->SetInput(
+            "Bias", std::vector<std::string>({eltwise_y_in_node->Name()}));
+        IR_NODE_LINK_TO(eltwise_y_in_node, conv);
+      }
+      conv->Op()->SetOutput("Output",
+                            std::vector<std::string>({bn_out->Name()}));
+
+      GraphSafeRemoveNodes(
+          graph.get(),
+          {conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
+           bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance});
+
+      IR_NODE_LINK_TO(conv, bn_out);
+      found_conv_bn_count++;
+    } else {  // fuse_option == FUSE_NATIVE
+      // create an elementwise add node.
+      OpDesc desc;
+      desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
+      desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
+      desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
+      desc.SetType("elementwise_add");
+      desc.SetAttr("axis", 1);
+      auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.
+
+      GraphSafeRemoveNodes(
+          graph.get(),
+          {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
+           bn_variance_out, bn_saved_mean, bn_saved_variance});
+
+      IR_NODE_LINK_TO(conv_out, eltwise_op);
+      IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
+      IR_NODE_LINK_TO(eltwise_op, bn_out);
+      found_conv_bn_count++;
+    }
   };
 
   gpd(graph.get(), handler);
@@ -237,7 +278,6 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
         {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
          bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});
 
-    PADDLE_ENFORCE(subgraph.count(conv_input));
     IR_NODE_LINK_TO(eltwise, bn_out);
 
     found_conv_bn_count++;
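The eltwise_y_in tensor filled just above (from bn_scale, bn_bias, bn_mean, bn_variance, and epsilon; the leading line of that call falls outside the hunk) is the usual batch-norm folding constant. As a reference sketch of the arithmetic in LaTeX, for the bias-free case and with per-output-channel broadcasting:

% Folding BN into a convolution, writing gamma/beta/mu/sigma^2/eps for
% bn_scale, bn_bias, bn_mean, bn_variance, epsilon:
%   y = \gamma \cdot \frac{\mathrm{conv}(x; W) - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta
% is again a convolution with rescaled weights and a constant bias:
\[
  W' = W \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}, \qquad
  b'  = \beta - \frac{\gamma \mu}{\sqrt{\sigma^2 + \varepsilon}}
\]
% b' is eltwise_y_in: under FUSE_MKLDNN it becomes (or is added into) the
% conv Bias input; under FUSE_NATIVE it feeds the new elementwise_add op.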

paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc

Lines changed: 6 additions & 0 deletions
@@ -46,6 +46,12 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
     GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern);  // Out
     GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern);          // ReLU op
 
+    FuseOptions fuse_option = FindFuseOption(*conv, *relu);
+    if (fuse_option == DO_NOT_FUSE) {
+      VLOG(3) << "do not perform conv+relu fuse";
+      return;
+    }
+
     // Transform Conv node into ConvReLU node.
     OpDesc* desc = conv->Op();
     desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));

paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc

Lines changed: 33 additions & 14 deletions
@@ -20,17 +20,19 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-void SetOp(ProgramDesc* prog, const std::string& type,
+void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
            const std::vector<std::string>& inputs,
-           const std::vector<std::string>& outputs) {
+           const std::vector<std::string>& outputs, bool use_mkldnn = false) {
   auto* op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
   if (type == "conv2d") {
-    op->SetAttr("use_mkldnn", true);
+    op->SetAttr("use_mkldnn", use_mkldnn);
+    op->SetAttr("name", name);
     op->SetInput("Input", {inputs[0]});
     op->SetInput("Filter", {inputs[1]});
     op->SetInput("Bias", {inputs[2]});
   } else if (type == "relu") {
+    op->SetAttr("use_mkldnn", use_mkldnn);
     op->SetInput("X", inputs);
   }
   op->SetOutput("Out", outputs);
@@ -43,22 +45,33 @@ void SetOp(ProgramDesc* prog, const std::string& type,
 ProgramDesc BuildProgramDesc() {
   ProgramDesc prog;
   for (auto& v :
-       std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g"})) {
+       std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
+                                 "h", "weights2", "bias2", "k", "l"})) {
     auto* var = prog.MutableBlock(0)->Var(v);
     var->SetType(proto::VarType::SELECTED_ROWS);
     if (v == "weights" || v == "bias") {
       var->SetPersistable(true);
     }
   }
 
-  SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
+  SetOp(&prog, "OP0", "op0", std::vector<std::string>({"a"}),
         std::vector<std::string>({"b"}));
-  SetOp(&prog, "OP1", std::vector<std::string>({"b"}),
+  SetOp(&prog, "OP1", "op1", std::vector<std::string>({"b"}),
         std::vector<std::string>({"c"}));
-  SetOp(&prog, "conv2d", std::vector<std::string>({"c", "weights", "bias"}),
-        std::vector<std::string>({"f"}));
-  SetOp(&prog, "relu", std::vector<std::string>({"f"}),
-        std::vector<std::string>({"g"}));
+  // conv+relu, both with MKL-DNN
+  SetOp(&prog, "conv2d", "conv1",
+        std::vector<std::string>({"c", "weights", "bias"}),
+        std::vector<std::string>({"f"}), true);
+  SetOp(&prog, "relu", "relu1", std::vector<std::string>({"f"}),
+        std::vector<std::string>({"g"}), true);
+  SetOp(&prog, "OP3", "op3", std::vector<std::string>({"g"}),
+        std::vector<std::string>({"h"}));
+  // conv+relu, only one with MKL-DNN
+  SetOp(&prog, "conv2d", "conv2",
+        std::vector<std::string>({"h", "weights2", "bias2"}),
+        std::vector<std::string>({"k"}), true);
+  SetOp(&prog, "relu", "relu2", std::vector<std::string>({"k"}),
+        std::vector<std::string>({"l"}));
 
   return prog;
 }
@@ -88,10 +101,16 @@ TEST(ConvReLUFusePass, basic) {
       auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("fuse_relu"));
-      bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu"));
-      if (fuse_relu) {
-        ++conv_relu_count;
+      // check if only "conv1" convolution is fused
+      auto op_name = boost::get<std::string>(op->GetAttr("name"));
+      if (op_name == "conv1") {
+        ASSERT_TRUE(op->HasAttr("fuse_relu"));
+        bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu"));
+        if (fuse_relu) {
+          ++conv_relu_count;
+        }
+      } else if (op_name == "conv2") {
+        ASSERT_FALSE(op->HasAttr("fuse_relu"));
       }
     }
   }
paddle/fluid/framework/ir/fuse_pass_base.cc (new file; name inferred from its #include and the cc_library(fuse_pass_base ...) rule above)

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void FusePassBase::Init(const std::string& repr, Graph* graph) const {
+  repr_ = repr;
+  graph_ = graph;
+}
+
+Scope* FusePassBase::param_scope() const {
+  PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
+  return graph_->Get<framework::Scope*>(kParamScopeAttr);
+}
+
+void FusePassBase::AddStatis(int count_of_fused) const {
+  PADDLE_ENFORCE(graph_);
+  PADDLE_ENFORCE(!repr_.empty());
+  if (!graph_->Has(kFuseStatisAttr)) {
+    graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
+  }
+  auto& info =
+      graph_->Get<std::unordered_map<std::string, int>>(kFuseStatisAttr);
+  info[repr_] = count_of_fused;
+}
+
+FuseOptions FusePassBase::FindFuseOption(const Node& node1,
+                                         const Node& node2) const {
+#ifdef PADDLE_WITH_MKLDNN
+  bool node1_mkldnn = node1.Op()->HasAttr("use_mkldnn") &&
+                      boost::get<bool>(node1.Op()->GetAttr("use_mkldnn"));
+  bool node2_mkldnn = node2.Op()->HasAttr("use_mkldnn") &&
+                      boost::get<bool>(node2.Op()->GetAttr("use_mkldnn"));
+  if (node1_mkldnn && node2_mkldnn)
+    return FUSE_MKLDNN;
+  else if (!node1_mkldnn && !node2_mkldnn)
+    return FUSE_NATIVE;
+  else
+    return DO_NOT_FUSE;
+#else
+  return FUSE_NATIVE;
+#endif
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle

paddle/fluid/framework/ir/fuse_pass_base.h

Lines changed: 12 additions & 20 deletions
@@ -25,32 +25,24 @@ namespace ir {
 static const char kParamScopeAttr[] = "__param_scope__";
 static const char kFuseStatisAttr[] = "__fuse_statis__";
 
+enum FuseOptions {
+  DO_NOT_FUSE,  // fusing will not be done
+  FUSE_NATIVE,  // fusing will be done without MKL-DNN
+  FUSE_MKLDNN   // fusing will be done with MKL-DNN
+};
+
 class FusePassBase : public Pass {
  public:
-  void Init(const std::string& repr, Graph* graph) const {
-    repr_ = repr;
-    graph_ = graph;
-  }
-
-  Scope* param_scope() const {
-    PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
-    return graph_->Get<framework::Scope*>(kParamScopeAttr);
-  }
-
-  void AddStatis(int count_of_fused) const {
-    PADDLE_ENFORCE(graph_);
-    PADDLE_ENFORCE(!repr_.empty());
-    if (!graph_->Has(kFuseStatisAttr)) {
-      graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
-    }
-    auto& info =
-        graph_->Get<std::unordered_map<std::string, int>>(kFuseStatisAttr);
-    info[repr_] = count_of_fused;
-  }
+  void Init(const std::string& repr, Graph* graph) const;
+  Scope* param_scope() const;
+  void AddStatis(int count_of_fused) const;
 
   virtual ~FusePassBase() {}
 
  protected:
+  virtual FuseOptions FindFuseOption(const Node& node1,
+                                     const Node& node2) const;
+
   mutable Graph* graph_;
   mutable std::string repr_;
 };
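Between these two files, derived fuse passes now follow a small protocol: Init() once per run, FindFuseOption() per matched pair of ops, AddStatis() at the end. A hypothetical skeleton showing that protocol (MyFusePass and "my_fuse" are illustrative names, not part of this commit):

// Hypothetical fuse pass sketched against the FusePassBase interface above.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

class MyFusePass : public FusePassBase {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override {
    Init("my_fuse", graph.get());  // stores repr_/graph_ for AddStatis()
    int found_count = 0;
    // For each matched pair of ops (op1, op2), mirror conv_bn_fuse_pass:
    //   FuseOptions option = FindFuseOption(*op1, *op2);
    //   DO_NOT_FUSE -> skip the pair (ops disagree on use_mkldnn);
    //   FUSE_MKLDNN -> fold into a single MKL-DNN op;
    //   FUSE_NATIVE -> emit the native fused form; then ++found_count.
    AddStatis(found_count);  // records the count under kFuseStatisAttr
    return graph;
  }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle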
paddle/fluid/framework/ir/mkldnn_placement_pass.cc (new file; name inferred from its #include)

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/ir/mkldnn_placement_pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  VLOG(3) << "Applies MKL-DNN placement strategy.";
+  for (const Node* n : graph->Nodes()) {
+    if (n->IsOp() && n->Op()->HasAttr("use_mkldnn")) {
+      n->Op()->SetAttr("use_mkldnn", true);
+    }
+  }
+  return graph;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(mkldnn_placement_pass,
+              paddle::framework::ir::MKLDNNPlacementPass);
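The companion header mkldnn_placement_pass.h is among the 12 changed files but its diff is not shown in this excerpt; inferring from the definition above, it would declare roughly the following (a sketch, not the verbatim file):

// Sketch of paddle/fluid/framework/ir/mkldnn_placement_pass.h, inferred
// from the .cc above; the actual header is not shown in this excerpt.
#pragma once

#include <memory>

#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

class MKLDNNPlacementPass : public Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle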
