Skip to content

Commit 4e2aaf0

Browse files
committed
add depthwise conv mkldnn pass
added depthwise conv mkldnn pass which for MKLDNN changes depthwise_conv operator to conv operator because for mkldnn this is the same api test=develop
1 parent e74267a commit 4e2aaf0

File tree

6 files changed

+220
-1
lines changed

6 files changed

+220
-1
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ pass_library(conv_bn_fuse_pass inference)
4141
pass_library(seqconv_eltadd_relu_fuse_pass inference)
4242
if(WITH_MKLDNN)
4343
pass_library(mkldnn_placement_pass base)
44+
pass_library(depthwise_conv_mkldnn_pass base)
4445
pass_library(conv_bias_mkldnn_fuse_pass inference)
4546
pass_library(conv_relu_mkldnn_fuse_pass inference)
4647
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference)
@@ -59,6 +60,7 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph
5960
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
6061
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
6162
if (WITH_MKLDNN)
63+
cc_test(test_depthwise_conv_mkldnn_pass SRCS depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
6264
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
6365
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
6466
endif ()

paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ class ConvReLUFusePass : public FusePassBase {
3131
virtual ~ConvReLUFusePass() {}
3232

3333
protected:
34-
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
34+
std::unique_ptr<ir::Graph> ApplyImpl(
35+
std::unique_ptr<ir::Graph> graph) const override;
3536
};
3637

3738
} // namespace ir
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h"
16+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
17+
18+
namespace paddle {
19+
namespace framework {
20+
namespace ir {
21+
22+
#define GET_NODE(id, pattern) \
23+
PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \
24+
"pattern has no Node called %s", #id); \
25+
auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
26+
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
27+
28+
std::unique_ptr<ir::Graph> DepthwiseConvMKLDNNPass::ApplyImpl(
29+
std::unique_ptr<ir::Graph> graph) const {
30+
PADDLE_ENFORCE(graph.get());
31+
FusePassBase::Init("depthwise_conv_mkldnn_pass", graph.get());
32+
GraphPatternDetector gpd;
33+
34+
auto* pattern = gpd.mutable_pattern();
35+
pattern->NewNode("depthwise_conv")
36+
->assert_is_op("depthwise_conv2d")
37+
->assert_op_attr("use_mkldnn", true);
38+
39+
int found_depthwise_conv_mkldnn_count = 0;
40+
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
41+
Graph* g) {
42+
VLOG(3) << "handle DepthwiseConvMKLDNN fuse";
43+
GET_NODE(depthwise_conv, (*pattern));
44+
depthwise_conv->Op()->SetType("conv2d");
45+
found_depthwise_conv_mkldnn_count++;
46+
};
47+
48+
gpd(graph.get(), handler);
49+
AddStatis(found_depthwise_conv_mkldnn_count);
50+
return graph;
51+
}
52+
53+
} // namespace ir
54+
} // namespace framework
55+
} // namespace paddle
56+
57+
REGISTER_PASS(depthwise_conv_mkldnn_pass,
58+
paddle::framework::ir::DepthwiseConvMKLDNNPass);
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
class DepthwiseConvMKLDNNPass : public FusePassBase {
24+
public:
25+
virtual ~DepthwiseConvMKLDNNPass() {}
26+
27+
protected:
28+
std::unique_ptr<ir::Graph> ApplyImpl(
29+
std::unique_ptr<ir::Graph> graph) const override;
30+
};
31+
32+
} // namespace ir
33+
} // namespace framework
34+
} // namespace paddle
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h"
16+
17+
#include <gtest/gtest.h>
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
24+
const std::vector<std::string>& inputs,
25+
const std::vector<std::string>& outputs, bool use_mkldnn = false) {
26+
auto* op = prog->MutableBlock(0)->AppendOp();
27+
op->SetType(type);
28+
op->SetAttr("use_mkldnn", use_mkldnn);
29+
op->SetAttr("name", name);
30+
op->SetInput("Input", {inputs[0]});
31+
op->SetInput("Filter", {inputs[1]});
32+
op->SetInput("Bias", {inputs[2]});
33+
op->SetOutput("Out", outputs);
34+
}
35+
36+
// (a, weights, bias)->depthwise conv mkldnn->b
37+
// (b, weights2, bias2)->depthwise conv no mkldnn->c
38+
// (c, weights3, bias3)->conv mkldnn->d
39+
// (d, weights3, bias3)->conv no mkldnn->e
40+
ProgramDesc BuildProgramDesc() {
41+
ProgramDesc prog;
42+
for (auto& v : std::vector<std::string>(
43+
{"a", "b", "c", "d", "e", "weights", "bias", "weights2", "bias2",
44+
"weights3", "bias3", "weights4", "bias4"})) {
45+
auto* var = prog.MutableBlock(0)->Var(v);
46+
var->SetType(proto::VarType::SELECTED_ROWS);
47+
if (v == "weights" || v == "bias" || v == "weights2" || v == "bias2" ||
48+
v == "weights3" || v == "bias3" || v == "weights4" || v == "bias4") {
49+
var->SetPersistable(true);
50+
}
51+
}
52+
53+
// depthwise conv with MKL-DNN
54+
SetOp(&prog, "depthwise_conv2d", "conv1",
55+
std::vector<std::string>({"a", "weights", "bias"}),
56+
std::vector<std::string>({"b"}), true);
57+
// depthwise conv without MKL-DNN
58+
SetOp(&prog, "depthwise_conv2d", "conv2",
59+
std::vector<std::string>({"b", "weights2", "bias2"}),
60+
std::vector<std::string>({"c"}), false);
61+
// conv with MKL-DNN
62+
SetOp(&prog, "conv2d", "conv3",
63+
std::vector<std::string>({"c", "weights3", "bias3"}),
64+
std::vector<std::string>({"d"}), true);
65+
// conv without MKL-dNN
66+
SetOp(&prog, "conv2d", "conv4",
67+
std::vector<std::string>({"d", "weights4", "bias4"}),
68+
std::vector<std::string>({"e"}), false);
69+
70+
return prog;
71+
}
72+
73+
TEST(DepthwiseConvMKLDNNPass, basic) {
74+
auto prog = BuildProgramDesc();
75+
76+
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
77+
78+
auto pass = PassRegistry::Instance().Get("depthwise_conv_mkldnn_pass");
79+
80+
struct counters {
81+
int mkldnn_depthwise_conv_nodes;
82+
int other_depthwise_conv_nodes;
83+
int mkldnn_conv_nodes;
84+
int other_conv_nodes;
85+
};
86+
87+
counters before{1, 1, 1, 1};
88+
89+
graph = pass->Apply(std::move(graph));
90+
91+
// initialize counters before loop
92+
counters after{0, 0, 0, 0};
93+
94+
for (auto* node : graph->Nodes()) {
95+
if (node->IsOp()) {
96+
auto* op = node->Op();
97+
if (op->Type() == "conv2d") {
98+
if (boost::get<bool>(op->GetAttr("use_mkldnn")))
99+
after.mkldnn_conv_nodes++;
100+
else
101+
after.other_conv_nodes++;
102+
} else if (op->Type() == "depthwise_conv2d") {
103+
if (boost::get<bool>(op->GetAttr("use_mkldnn")))
104+
after.mkldnn_depthwise_conv_nodes++;
105+
else
106+
after.other_depthwise_conv_nodes++;
107+
}
108+
}
109+
}
110+
111+
EXPECT_EQ(after.other_depthwise_conv_nodes,
112+
before.other_depthwise_conv_nodes);
113+
EXPECT_EQ(after.other_conv_nodes, before.other_conv_nodes);
114+
EXPECT_EQ(after.mkldnn_depthwise_conv_nodes,
115+
before.mkldnn_depthwise_conv_nodes - 1);
116+
EXPECT_EQ(after.mkldnn_conv_nodes, before.mkldnn_conv_nodes + 1);
117+
}
118+
119+
} // namespace ir
120+
} // namespace framework
121+
} // namespace paddle
122+
123+
USE_PASS(depthwise_conv_mkldnn_pass);

paddle/fluid/inference/analysis/analyzer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class Analyzer : public OrderedRegistry<PassManager> {
7979
"conv_bn_fuse_pass", //
8080
"conv_eltwiseadd_bn_fuse_pass", //
8181
#ifdef PADDLE_WITH_MKLDNN
82+
"depthwise_conv_mkldnn_pass", //
8283
"conv_bias_mkldnn_fuse_pass", //
8384
"conv_relu_mkldnn_fuse_pass", //
8485
"conv_elementwise_add_mkldnn_fuse_pass", //

0 commit comments

Comments
 (0)