Skip to content

Commit 1c591c3

Browse files
authored
Merge branch 'develop' into fix_rpn_target_assign_op
2 parents f06c619 + 96e9b65 commit 1c591c3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+1978
-553
lines changed

cmake/generic.cmake

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,13 @@ function(cc_library TARGET_NAME)
261261
add_dependencies(${TARGET_NAME} mklml)
262262
target_link_libraries(${TARGET_NAME} "-L${MKLML_LIB_DIR} -liomp5 -Wl,--as-needed")
263263
endif()
264+
# remove link to python, see notes at:
265+
# https://github.com/pybind/pybind11/blob/master/docs/compiling.rst#building-manually
266+
if("${cc_library_DEPS};" MATCHES "python;")
267+
list(REMOVE_ITEM cc_library_DEPS python)
268+
add_dependencies(${TARGET_NAME} python)
269+
target_link_libraries(${TARGET_NAME} "-Wl,-undefined,dynamic_lookup")
270+
endif()
264271
target_link_libraries(${TARGET_NAME} ${cc_library_DEPS})
265272
add_dependencies(${TARGET_NAME} ${cc_library_DEPS})
266273
endif()

paddle/fluid/framework/details/var_handle.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ struct VarHandleBase {
4949

5050
void AddOutput(OpHandleBase* out, ir::Node* node) {
5151
if (pending_ops_.find(out) == pending_ops_.end()) {
52+
PADDLE_ENFORCE(out != nullptr, "The output of %s should not be nullptr",
53+
this->Node()->Name());
5254
pending_ops_.insert(out);
5355
node_->outputs.push_back(node);
5456
}

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,17 @@ pass_library(embedding_fc_lstm_fuse_pass inference)
3737
pass_library(fc_gru_fuse_pass inference)
3838
pass_library(seq_concat_fc_fuse_pass inference)
3939
pass_library(conv_bn_fuse_pass inference)
40+
pass_library(seqconv_eltadd_relu_fuse_pass inference)
4041
if(WITH_MKLDNN)
4142
pass_library(mkldnn_placement_pass base)
4243
pass_library(conv_bias_mkldnn_fuse_pass inference)
4344
pass_library(conv_relu_mkldnn_fuse_pass inference)
4445
endif()
4546

4647
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
48+
if(WITH_MKLDNN)
49+
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference)
50+
endif()
4751

4852
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
4953

@@ -57,4 +61,5 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g
5761
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
5862
if (WITH_MKLDNN)
5963
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
64+
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
6065
endif ()
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include <algorithm>
#include <functional>
#include <memory>
#include <tuple>
#include <utility>

#include "paddle/fluid/framework/ir/graph_traits.h"
20+
21+
namespace paddle {
22+
namespace framework {
23+
namespace ir {
24+
namespace {
25+
26+
// The function keeps the graph consistent by replacing
27+
// a node 'from' in the set of inputs nodes
28+
// of the visited node by a node 'to'.
29+
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
30+
for (auto& node : GraphTraits::DFS(*graph)) {
31+
auto from_in_inputs =
32+
std::find(std::begin(node.inputs), std::end(node.inputs), from);
33+
34+
if (from_in_inputs != std::end(node.inputs)) {
35+
IR_NODE_LINK_TO(to, (&node));
36+
37+
auto inputs = node.Op()->Inputs();
38+
39+
using input_type = VariableNameMap::value_type;
40+
41+
std::for_each(std::begin(inputs), std::end(inputs),
42+
[from, to, &node](const input_type& i) -> void {
43+
auto param_names = i.second;
44+
auto pi = std::find(std::begin(param_names),
45+
std::end(param_names), from->Name());
46+
47+
if (pi != std::end(param_names)) {
48+
node.Op()->SetInput(i.first, {to->Name()});
49+
}
50+
});
51+
}
52+
}
53+
}
54+
} // namespace
55+
using graph_ptr = std::unique_ptr<ir::Graph>;
56+
57+
// Fuses a conv op followed by an elementwise_add op into a single conv2d op
// that receives the other addend as "ResidualData".
//
// For every match of the Conv -> ElementwiseAdd pattern for which
// FindFuseOption() reports FUSE_MKLDNN, a new "conv2d" op node is created
// that:
//   * reuses the matched conv's Input and Filter (and Bias, when present);
//   * takes elementwise_add_x as "ResidualData";
//   * copies all attributes of the matched conv op and additionally sets
//     "fuse_residual_connection" to true;
//   * writes to the matched conv's output variable.
// Consumers of the elementwise_add output are then redirected to the conv
// output, and the old conv op, the elementwise_add op and the
// elementwise_add output node are removed from the graph.
//
// graph: the graph to rewrite; ownership is taken and handed back.
// Returns the same graph object after rewriting.
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
  FusePassBase::Init(name_scope_, graph.get());

  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();

  patterns::Conv conv_pattern{pattern, name_scope_};
  auto conv_output = conv_pattern();

  patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
  elementwise_add_pattern(conv_output);

  conv_output->AsIntermediate();

  // Looks up the bias of 'conv_op'.
  // Returns {false, nullptr} when the op declares no (or an empty) "Bias"
  // input. Otherwise returns {true, node}, where 'node' is the input node
  // whose name equals the first Bias argument, or nullptr when no such
  // node is linked to the op.
  auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> {
    const auto& conv_inputs = conv_op.Op()->Inputs();
    auto bias_it = conv_inputs.find("Bias");

    if (bias_it == std::end(conv_inputs) || bias_it->second.empty()) {
      return std::pair<bool, Node*>(false, nullptr);
    }

    const auto& bias_name = bias_it->second[0];
    auto bias_node_it =
        std::find_if(std::begin(conv_op.inputs), std::end(conv_op.inputs),
                     [&bias_name](Node* n) -> bool {
                       return n->Name() == bias_name;
                     });

    // Never dereference the end iterator: report a missing node as nullptr.
    Node* bias_node =
        bias_node_it != std::end(conv_op.inputs) ? *bias_node_it : nullptr;
    return std::pair<bool, Node*>(true, bias_node);
  };

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
                              elementwise_add_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
                              elementwise_add_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
                              elementwise_add_pattern);

    if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;

    bool has_bias = false;
    Node* conv_bias = nullptr;
    std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op);

    // The op declares a Bias but no matching input node is linked to it.
    // Fusing would either drop the bias or dereference a null node, so
    // leave this match untouched.
    if (has_bias && conv_bias == nullptr) return;

    OpDesc op_desc;
    op_desc.SetType("conv2d");

    op_desc.SetInput("Input", {conv_input->Name()});
    op_desc.SetInput("Filter", {conv_filter->Name()});
    op_desc.SetInput("ResidualData", {elementwise_add_x->Name()});
    op_desc.SetOutput("Output", {conv_output->Name()});

    if (has_bias) {
      op_desc.SetInput("Bias", {conv_bias->Name()});
    }

    // Carry over every attribute of the original conv before marking the
    // new op as a fused residual connection.
    for (const auto& attr : conv_op->Op()->GetAttrMap()) {
      op_desc.SetAttr(attr.first, attr.second);
    }

    op_desc.SetAttr("fuse_residual_connection", true);

    auto fused_conv_op = g->CreateOpNode(&op_desc);

    IR_NODE_LINK_TO(conv_input, fused_conv_op);
    IR_NODE_LINK_TO(conv_filter, fused_conv_op);
    IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op);
    IR_NODE_LINK_TO(fused_conv_op, conv_output);

    if (has_bias) {
      IR_NODE_LINK_TO(conv_bias, fused_conv_op);
    }

    // Redirect consumers of the add's output before the node is removed.
    CorrectGraphEdges(g, elementwise_add_out, conv_output);
    GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
  };

  gpd(graph.get(), handler);

  return graph;
}
149+
} // namespace ir
150+
} // namespace framework
151+
} // namespace paddle
152+
153+
// Associates the name conv_elementwise_add_mkldnn_fuse_pass with the
// ConvElementwiseAddMKLDNNFusePass class through the REGISTER_PASS macro.
// NOTE(review): presumably this makes the pass retrievable by that name from
// the pass registry; confirm against the REGISTER_PASS definition.
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
154+
paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
19+
#include "paddle/fluid/framework/ir/graph.h"
20+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
21+
22+
namespace paddle {
23+
namespace framework {
24+
namespace ir {
25+
26+
class ConvElementwiseAddMKLDNNFusePass : public FusePassBase {
27+
public:
28+
virtual ~ConvElementwiseAddMKLDNNFusePass() {}
29+
30+
protected:
31+
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
32+
33+
const std::string name_scope_{"residual_connections_fuse_pass"};
34+
};
35+
36+
} // namespace ir
37+
} // namespace framework
38+
} // namespace paddle

0 commit comments

Comments
 (0)