Skip to content

Commit a9d7a9d

Browse files
committed
test=develop
2 parents 5a38930 + 5d6783f commit a9d7a9d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+6053
-472
lines changed

cmake/generic.cmake

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,13 @@ function(cc_library TARGET_NAME)
261261
add_dependencies(${TARGET_NAME} mklml)
262262
target_link_libraries(${TARGET_NAME} "-L${MKLML_LIB_DIR} -liomp5 -Wl,--as-needed")
263263
endif()
264+
# remove link to python, see notes at:
265+
# https://github.com/pybind/pybind11/blob/master/docs/compiling.rst#building-manually
266+
if("${cc_library_DEPS};" MATCHES "python;")
267+
list(REMOVE_ITEM cc_library_DEPS python)
268+
add_dependencies(${TARGET_NAME} python)
269+
target_link_libraries(${TARGET_NAME} "-Wl,-undefined,dynamic_lookup")
270+
endif()
264271
target_link_libraries(${TARGET_NAME} ${cc_library_DEPS})
265272
add_dependencies(${TARGET_NAME} ${cc_library_DEPS})
266273
endif()

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ paddle.fluid.layers.pad ArgSpec(args=['x', 'paddings', 'pad_value', 'name'], var
116116
paddle.fluid.layers.pad_constant_like ArgSpec(args=['x', 'y', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0.0, None))
117117
paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 0.1, 'float32', None))
118118
paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0))
119+
paddle.fluid.layers.roi_align ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None))
119120
paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,))
120121
paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR'))
121122
paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',))

paddle/fluid/framework/details/var_handle.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ struct VarHandleBase {
4949

5050
void AddOutput(OpHandleBase* out, ir::Node* node) {
5151
if (pending_ops_.find(out) == pending_ops_.end()) {
52+
PADDLE_ENFORCE(out != nullptr, "The output of %s should not be nullptr",
53+
this->Node()->Name());
5254
pending_ops_.insert(out);
5355
node_->outputs.push_back(node);
5456
}

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,17 @@ pass_library(embedding_fc_lstm_fuse_pass inference)
3737
pass_library(fc_gru_fuse_pass inference)
3838
pass_library(seq_concat_fc_fuse_pass inference)
3939
pass_library(conv_bn_fuse_pass inference)
40+
pass_library(seqconv_eltadd_relu_fuse_pass inference)
4041
if(WITH_MKLDNN)
4142
pass_library(mkldnn_placement_pass base)
43+
pass_library(conv_bias_mkldnn_fuse_pass inference)
4244
pass_library(conv_relu_mkldnn_fuse_pass inference)
4345
endif()
4446

4547
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
48+
if(WITH_MKLDNN)
49+
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference)
50+
endif()
4651

4752
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
4853

@@ -56,4 +61,5 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g
5661
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
5762
if (WITH_MKLDNN)
5863
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
64+
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
5965
endif ()
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
16+
#include <functional>
17+
#include <string>
18+
#include <vector>
19+
#include "paddle/fluid/framework/lod_tensor.h"
20+
#include "paddle/fluid/platform/enforce.h"
21+
22+
namespace paddle {
23+
namespace framework {
24+
namespace ir {
25+
26+
template <typename BinaryOperation>
27+
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
28+
BinaryOperation f) {
29+
PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims());
30+
LoDTensor vec_y;
31+
vec_y.Resize(vec_a.dims());
32+
const float* a = vec_a.data<float>();
33+
const float* b = vec_b.data<float>();
34+
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
35+
for (int i = 0; i < vec_a.numel(); i++) {
36+
y[i] = f(a[i], b[i]);
37+
}
38+
return vec_y;
39+
}
40+
41+
// Detects conv2d -> elementwise_add subgraphs and folds the added tensor into
// the conv2d "Bias" input. If the conv already has a bias, the two bias
// tensors are summed in place; otherwise a new conv2d op carrying the eltwise
// bias is created. Records the number of fusions via AddStatis.
std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init(name_scope_, graph.get());

  // Weights/bias tensors live in the scope attached by the fuse-pass base.
  auto* scope = param_scope();
  PADDLE_ENFORCE(scope);

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input("conv2d", "Input");
  patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bias_pattern(conv_input);
  int found_conv_bias_count = 0;
  // Invoked once per matched conv2d + elementwise_add subgraph.
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvBias fuse";
    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
                              conv_bias_pattern);                      // Filter
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_bias_pattern);  // tmp
    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_bias_pattern);  // CONV op
    // bias
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_bias, eltwise_bias, conv_bias_pattern);
    // output
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bias_pattern);
    // elementwise_add op
    GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);

    PADDLE_ENFORCE(subgraph.count(conv_input));

    // check if fuse can be done and if MKL-DNN should be used
    FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
    if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {
      VLOG(3) << "do not perform conv+bias fuse";
      return;
    }

    auto* eltwise_bias_tensor =
        scope->FindVar(eltwise_bias->Name())->GetMutable<LoDTensor>();

    auto input_names = conv->Op()->InputNames();
    bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
                    input_names.end();
    if (has_bias && conv->Op()->Input("Bias").size() > 0) {
      auto conv_bias_names = conv->Op()->Input("Bias");
      // add eltwise bias to existing conv bias
      PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
      auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
      auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
      PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), eltwise_bias_tensor->dims());
      *conv_bias_tensor = tensor_apply_eltwise(
          *conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>());

      // Rewire the existing conv to produce the eltwise output directly,
      // then drop the now-redundant add op and intermediate tensor node.
      conv->Op()->SetOutput("Output",
                            std::vector<std::string>({eltwise_out->Name()}));

      GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});

      IR_NODE_LINK_TO(conv, eltwise_out);
    } else {
      // take eltwise bias as conv bias
      OpDesc desc;

      desc.SetInput(
          "Input", std::vector<std::string>({subgraph.at(conv_input)->Name()}));
      desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
      desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
      desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
      desc.SetType("conv2d");

      // Preserve all attributes of the original conv on the fused op.
      for (auto& attr : conv->Op()->GetAttrMap()) {
        desc.SetAttr(attr.first, attr.second);
      }
      auto conv_bias_node = g->CreateOpNode(&desc);

      IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
      IR_NODE_LINK_TO(conv_weight, conv_bias_node);
      IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
      IR_NODE_LINK_TO(conv_bias_node, eltwise_out);

      GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
    }

    found_conv_bias_count++;
  };
  gpd(graph.get(), handler);
  AddStatis(found_conv_bias_count);
  return graph;
}
133+
} // namespace ir
134+
} // namespace framework
135+
} // namespace paddle
136+
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
137+
paddle::framework::ir::ConvBiasFusePass);
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#pragma once
15+
#include <string>
16+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
17+
#include "paddle/fluid/framework/ir/graph.h"
18+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
19+
#include "paddle/fluid/framework/ir/pass.h"
20+
namespace paddle {
21+
namespace framework {
22+
namespace ir {
23+
/*
24+
* Fuse the Conv and Elementwise_add to a ConvBiasOp.
25+
*/
26+
class ConvBiasFusePass : public FusePassBase {
27+
public:
28+
virtual ~ConvBiasFusePass() {}
29+
30+
protected:
31+
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
32+
const std::string name_scope_{"conv_bias_mkldnn_fuse"};
33+
};
34+
} // namespace ir
35+
} // namespace framework
36+
} // namespace paddle
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
16+
#include <functional>
17+
#include <utility>
18+
19+
#include "paddle/fluid/framework/ir/graph_traits.h"
20+
21+
namespace paddle {
22+
namespace framework {
23+
namespace ir {
24+
namespace {
25+
26+
// The function keeps the graph consistent by replacing
27+
// a node 'from' in the set of inputs nodes
28+
// of the visited node by a node 'to'.
29+
// The function keeps the graph consistent by replacing
// a node 'from' in the set of input nodes
// of every visited node by a node 'to' — both in the IR edge lists and in the
// corresponding OpDesc input name lists.
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
  for (auto& node : GraphTraits::DFS(*graph)) {
    auto from_in_inputs =
        std::find(std::begin(node.inputs), std::end(node.inputs), from);

    if (from_in_inputs != std::end(node.inputs)) {
      IR_NODE_LINK_TO(to, (&node));

      auto inputs = node.Op()->Inputs();

      using input_type = VariableNameMap::value_type;

      std::for_each(std::begin(inputs), std::end(inputs),
                    [from, to, &node](const input_type& i) -> void {
                      auto param_names = i.second;
                      auto pi = std::find(std::begin(param_names),
                                          std::end(param_names), from->Name());

                      if (pi != std::end(param_names)) {
                        // Replace only the matching entry. The previous code
                        // overwrote the whole list with {to->Name()}, which
                        // silently dropped sibling names when a parameter had
                        // more than one input variable.
                        *pi = to->Name();
                        node.Op()->SetInput(i.first, param_names);
                      }
                    });
    }
  }
}
54+
} // namespace
55+
using graph_ptr = std::unique_ptr<ir::Graph>;
56+
57+
// Fuses conv2d -> elementwise_add into a single MKL-DNN conv2d that consumes
// the add's other operand through the "ResidualData" input and sets the
// "fuse_residual_connection" attribute. The matched pair is replaced by a new
// conv op; downstream consumers of the add's output are rewired to the conv
// output via CorrectGraphEdges.
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
  FusePassBase::Init(name_scope_, graph.get());

  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();

  patterns::Conv conv_pattern{pattern, name_scope_};
  auto conv_output = conv_pattern();

  patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
  elementwise_add_pattern(conv_output);

  // The conv output only feeds the add inside the match, so it may be
  // consumed by the fused op.
  conv_output->AsIntermediate();

  // Returns (true, bias-node) if the conv op carries a non-empty "Bias"
  // input, (false, nullptr) otherwise.
  // NOTE(review): if the "Bias" name is present in the OpDesc but no graph
  // node with that name exists among conv_op.inputs, conv_bias_names_it is
  // the end iterator and dereferencing it is undefined — confirm the pattern
  // guarantees the node exists.
  auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> {
    auto bias_input_names = conv_op.Op()->Inputs();
    auto bias_it = bias_input_names.find("Bias");

    if (bias_it != std::end(bias_input_names)) {
      bool has_bias = !bias_it->second.empty();

      if (has_bias) {
        auto conv_bias_names = bias_it->second;
        auto conv_bias_names_it =
            std::find_if(std::begin(conv_op.inputs), std::end(conv_op.inputs),
                         [&conv_bias_names](Node* n) -> bool {
                           return n->Name() == conv_bias_names[0];
                         });
        return std::make_pair(has_bias, *conv_bias_names_it);
      }
    }

    return std::make_pair(false, nullptr);
  };

  // Invoked once per matched conv2d + elementwise_add subgraph.
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
                              elementwise_add_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
                              elementwise_add_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
                              elementwise_add_pattern);

    // Only fuse when both ops resolved to the MKL-DNN implementation.
    if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;

    OpDesc op_desc;
    op_desc.SetType("conv2d");

    op_desc.SetInput("Input", {conv_input->Name()});
    op_desc.SetInput("Filter", {conv_filter->Name()});
    // The add's other operand becomes the residual input of the fused conv.
    op_desc.SetInput("ResidualData", {elementwise_add_x->Name()});
    op_desc.SetOutput("Output", {conv_output->Name()});

    bool has_bias;
    Node* conv_bias;

    std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op);

    if (has_bias) {
      op_desc.SetInput("Bias", {conv_bias->Name()});
    }

    // Carry over every attribute from the original conv.
    for (const auto& attr : conv_op->Op()->GetAttrMap()) {
      op_desc.SetAttr(attr.first, attr.second);
    }

    op_desc.SetAttr("fuse_residual_connection", true);

    auto fused_conv_op = g->CreateOpNode(&op_desc);

    IR_NODE_LINK_TO(conv_input, fused_conv_op);
    IR_NODE_LINK_TO(conv_filter, fused_conv_op);
    IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op);
    IR_NODE_LINK_TO(fused_conv_op, conv_output);

    if (has_bias) {
      IR_NODE_LINK_TO(conv_bias, fused_conv_op);
    }

    // Redirect consumers of the removed add output to the conv output,
    // then delete the replaced nodes.
    CorrectGraphEdges(g, elementwise_add_out, conv_output);
    GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
    // NOTE(review): unlike ConvBiasFusePass, no fusion count is recorded via
    // AddStatis here — confirm whether statistics are intentionally omitted.
  };

  gpd(graph.get(), handler);

  return graph;
}
149+
} // namespace ir
150+
} // namespace framework
151+
} // namespace paddle
152+
153+
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
154+
paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);

0 commit comments

Comments
 (0)