Skip to content

Commit ce2464f

Browse files
author
Tomasz Patejko
committed
MKLDNN conv + elementwise_add fusion: added a UT for the missing-bias case, refactored the existing UTs, and made some minor changes to the pass
1 parent 4e72ab4 commit ce2464f

File tree

4 files changed

+99
-111
lines changed

4 files changed

+99
-111
lines changed

paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
6868

6969
conv_output->AsIntermediate();
7070

71-
auto conv_op_has_bias = [](const Node& conv_op,
72-
const Scope& scope) -> std::pair<bool, Node*> {
71+
auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> {
7372
auto bias_input_names = conv_op.Op()->Inputs();
7473
auto bias_it = bias_input_names.find("Bias");
7574

@@ -116,7 +115,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
116115
bool has_bias;
117116
Node* conv_bias;
118117

119-
std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op, *param_scope());
118+
std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op);
120119

121120
if (has_bias) {
122121
op_desc.SetInput("Bias", {conv_bias->Name()});

paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc

Lines changed: 96 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -22,29 +22,22 @@ namespace paddle {
2222
namespace framework {
2323
namespace ir {
2424

25+
namespace {
2526
constexpr int nodes_removed = 3;
2627
constexpr int nodes_added = 1;
2728

2829
void SetOp(ProgramDesc* prog, const std::string& type,
29-
const std::vector<std::string>& inputs,
30-
const std::vector<std::string>& outputs) {
30+
const std::vector<std::pair<std::string, std::string>>& inputs,
31+
const std::pair<std::string, std::string>& output) {
3132
auto op = prog->MutableBlock(0)->AppendOp();
3233
op->SetType(type);
34+
op->SetAttr("use_mkldnn", true);
3335

34-
if (type == "conv2d") {
35-
op->SetAttr("use_mkldnn", true);
36-
op->SetInput("Input", {inputs[0]});
37-
op->SetInput("Bias", {inputs[1]});
38-
op->SetInput("Filter", {inputs[2]});
39-
op->SetOutput("Output", outputs);
40-
} else if (type == "elementwise_add") {
41-
op->SetInput("X", {inputs[0]});
42-
op->SetInput("Y", {inputs[1]});
43-
op->SetOutput("Out", outputs);
44-
} else if (type == "relu" || type == "sigmoid") {
45-
op->SetInput("X", {inputs[0]});
46-
op->SetOutput("Out", outputs);
36+
for (const auto& input : inputs) {
37+
op->SetInput(input.first, {input.second});
4738
}
39+
40+
op->SetOutput(output.first, {output.second});
4841
}
4942

5043
struct IsReachable {
@@ -96,30 +89,59 @@ struct IsReachable {
9689
}
9790
};
9891

99-
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
100-
auto build_program_desc = [&]() -> ProgramDesc {
101-
ProgramDesc prog;
102-
for (auto& v : std::vector<std::string>(
103-
{"a", "b", "bias", "weights", "c", "d", "e", "f"})) {
104-
auto* var = prog.MutableBlock(0)->Var(v);
105-
var->SetType(proto::VarType::LOD_TENSOR);
106-
if (v == "weights" || v == "bias") {
107-
var->SetPersistable(true);
108-
}
92+
void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
93+
int conv_count = 0;
94+
int elementwise_add_count = 0;
95+
96+
for (auto* node : graph->Nodes()) {
97+
if (node->IsOp() && node->Op()->Type() == "conv2d") {
98+
++conv_count;
99+
}
100+
if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
101+
++elementwise_add_count;
109102
}
103+
}
104+
EXPECT_EQ(conv_count, 1);
105+
EXPECT_EQ(elementwise_add_count, 0);
106+
}
107+
108+
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
109+
const std::vector<std::string>& persistent_vars) {
110+
ProgramDesc prog;
110111

111-
SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"});
112-
SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
113-
SetOp(&prog, "relu", {"d"}, {"e"});
112+
auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* {
113+
auto var = prog.MutableBlock(0)->Var(var_name);
114+
var->SetType(proto::VarType::LOD_TENSOR);
114115

115-
return prog;
116+
return var;
116117
};
117118

118-
auto prog = build_program_desc();
119+
for (const auto& v : transient_vars) {
120+
add_var_to_prog(v);
121+
}
122+
123+
for (const auto& v : persistent_vars) {
124+
auto var = add_var_to_prog(v);
125+
var->SetPersistable(true);
126+
}
127+
128+
return prog;
129+
}
130+
} // namespace
131+
132+
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
133+
auto prog =
134+
BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
135+
136+
SetOp(&prog, "conv2d",
137+
{{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
138+
{"Output", "b"});
139+
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
140+
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
141+
119142
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
120143

121144
IsReachable is_reachable;
122-
123145
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
124146

125147
auto pass =
@@ -132,40 +154,45 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
132154

133155
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
134156
current_nodes_num);
135-
// Assert conv_relu op in newly generated graph
136-
int conv_count = 0;
137-
int elementwise_add_count = 0;
138157

139-
for (auto* node : graph->Nodes()) {
140-
if (node->IsOp() && node->Op()->Type() == "conv2d") {
141-
++conv_count;
142-
}
143-
if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
144-
++elementwise_add_count;
145-
}
146-
}
147-
EXPECT_EQ(conv_count, 1);
148-
EXPECT_EQ(elementwise_add_count, 0);
158+
AssertOpsCount(graph);
149159
}
150160

151-
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
152-
auto build_program_desc = [&]() -> ProgramDesc {
153-
ProgramDesc prog;
154-
for (auto& v : std::vector<std::string>({"a", "b", "bias", "weights"})) {
155-
auto* var = prog.MutableBlock(0)->Var(v);
156-
var->SetType(proto::VarType::LOD_TENSOR);
157-
if (v == "weights" || v == "bias") {
158-
var->SetPersistable(true);
159-
}
160-
}
161+
TEST(ConvElementwiseAddMKLDNNFusePass,
162+
ConvolutionWithElementwiseAddReluNoBias) {
163+
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
164+
SetOp(&prog, "conv2d", {{"Input", "a"}, {"Filter", "weights"}},
165+
{"Output", "b"});
166+
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
167+
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
161168

162-
SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"});
163-
SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
169+
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
164170

165-
return prog;
166-
};
171+
IsReachable is_reachable;
172+
173+
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
174+
175+
auto pass =
176+
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
177+
int original_nodes_num = graph->Nodes().size();
178+
graph = pass->Apply(std::move(graph));
179+
int current_nodes_num = graph->Nodes().size();
180+
181+
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
182+
183+
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
184+
current_nodes_num);
185+
186+
AssertOpsCount(graph);
187+
}
188+
189+
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
190+
auto prog = BuildProgramDesc({"a", "b", "c", "d"}, {"bias", "weights"});
191+
SetOp(&prog, "conv2d",
192+
{{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
193+
{"Output", "b"});
194+
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
167195

168-
auto prog = build_program_desc();
169196
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
170197

171198
IsReachable is_reachable;
@@ -181,43 +208,19 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
181208

182209
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
183210
current_nodes_num);
184-
// Assert conv_relu op in newly generated graph
185-
int conv_count = 0;
186-
int elementwise_add_count = 0;
187-
188-
for (auto* node : graph->Nodes()) {
189-
if (node->IsOp() && node->Op()->Type() == "conv2d") {
190-
++conv_count;
191-
}
192-
if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
193-
++elementwise_add_count;
194-
}
195-
}
196-
EXPECT_EQ(conv_count, 1);
197-
EXPECT_EQ(elementwise_add_count, 0);
211+
AssertOpsCount(graph);
198212
}
199213

200214
TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
201-
auto build_program_desc = [&]() -> ProgramDesc {
202-
ProgramDesc prog;
203-
for (auto& v : std::vector<std::string>(
204-
{"a", "b", "bias", "weights", "c", "d", "e", "f"})) {
205-
auto* var = prog.MutableBlock(0)->Var(v);
206-
var->SetType(proto::VarType::LOD_TENSOR);
207-
if (v.find("weights") || v.find("bias")) {
208-
var->SetPersistable(true);
209-
}
210-
}
211-
212-
SetOp(&prog, "sigmoid", {"a"}, {"b"});
213-
SetOp(&prog, "conv2d", {"b", "bias", "weights"}, {"c"});
214-
SetOp(&prog, "elementwise_add", {"c", "d"}, {"e"});
215-
SetOp(&prog, "relu", {"e"}, {"f"});
216-
217-
return prog;
218-
};
215+
auto prog =
216+
BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
217+
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
218+
SetOp(&prog, "conv2d",
219+
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
220+
{"Output", "c"});
221+
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "d"}}, {"Out", "e"});
222+
SetOp(&prog, "relu", {{"X", "e"}}, {"Out", "f"});
219223

220-
auto prog = build_program_desc();
221224
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
222225

223226
IsReachable is_reachable;
@@ -234,20 +237,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
234237

235238
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
236239
current_nodes_num);
237-
// Assert conv_relu op in newly generated graph
238-
int conv_count = 0;
239-
int elementwise_add_count = 0;
240-
241-
for (auto* node : graph->Nodes()) {
242-
if (node->IsOp() && node->Op()->Type() == "conv2d") {
243-
++conv_count;
244-
}
245-
if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
246-
++elementwise_add_count;
247-
}
248-
}
249-
EXPECT_EQ(conv_count, 1);
250-
EXPECT_EQ(elementwise_add_count, 0);
240+
AssertOpsCount(graph);
251241
}
252242

253243
} // namespace ir

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1014,7 +1014,7 @@ PDNode *patterns::Conv::operator()() {
10141014
->AsOutput()
10151015
->assert_is_op_output("conv2d", "Output");
10161016

1017-
conv_op->LinksFrom({input_var, /*bias_var,*/ filter_var});
1017+
conv_op->LinksFrom({input_var, filter_var});
10181018
conv_op->LinksTo({output_var});
10191019

10201020
return output_var;

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,6 @@ struct Conv : public PatternBase {
617617

618618
PATTERN_DECL_NODE(conv_op);
619619
PATTERN_DECL_NODE(conv_input);
620-
PATTERN_DECL_NODE(conv_bias);
621620
PATTERN_DECL_NODE(conv_filter);
622621
PATTERN_DECL_NODE(conv_residual_data);
623622
PATTERN_DECL_NODE(conv_output);

0 commit comments

Comments
 (0)