Skip to content

Commit 2888d2d

Browse files
committed
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into parallel_bcast
2 parents 4778c6e + e69d9c8 commit 2888d2d

25 files changed

+398
-86
lines changed

README.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Our vision is to enable deep learning for everyone via PaddlePaddle.
1919
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
2020

2121

22-
### Latest PaddlePaddle Release: [Fluid 0.14.0](https://github.com/PaddlePaddle/Paddle/tree/v0.14.0)
22+
### Latest PaddlePaddle Release: [Fluid 0.15.0](https://github.com/PaddlePaddle/Paddle/tree/v0.15.0)
2323
### Install Latest Stable Release:
2424
```
2525
# Linux CPU
@@ -76,26 +76,26 @@ pip install paddlepaddle-gpu==0.14.0.post85
7676

7777
## Installation
7878

79-
It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/beginners_guide/install/install_doc.html) on our website.
79+
It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/beginners_guide/install/install_doc.html) on our website.
8080

8181
## Documentation
8282

83-
We provide [English](http://paddlepaddle.org/documentation/docs/en/0.14.0/getstarted/index_en.html) and
84-
[Chinese](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/beginners_guide/index.html) documentation.
83+
We provide [English](http://paddlepaddle.org/documentation/docs/en/0.15.0/getstarted/index_en.html) and
84+
[Chinese](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/beginners_guide/index.html) documentation.
8585

8686
- [Deep Learning 101](https://github.com/PaddlePaddle/book)
8787

8888
You might want to start from this online interactive book that can run in a Jupyter Notebook.
8989

90-
- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/user_guides/howto/training/cluster_howto.html)
90+
- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/user_guides/howto/training/cluster_howto.html)
9191

9292
You can run distributed training jobs on MPI clusters.
9393

94-
- [Python API](http://paddlepaddle.org/documentation/api/zh/0.14.0/fluid.html)
94+
- [Python API](http://paddlepaddle.org/documentation/api/zh/0.15.0/fluid.html)
9595

9696
Our new API enables much shorter programs.
9797

98-
- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/advanced_usage/development/contribute_to_paddle.html)
98+
- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/advanced_usage/development/contribute_to_paddle.html)
9999

100100
We appreciate your contributions!
101101

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph grap
2828
pass_library(graph_to_program_pass base)
2929
pass_library(graph_viz_pass base)
3030
pass_library(fc_fuse_pass inference)
31+
if(WITH_MKLDNN)
32+
pass_library(conv_relu_mkldnn_fuse_pass inference)
33+
endif()
3134
pass_library(attention_lstm_fuse_pass inference)
3235
pass_library(infer_clean_graph_pass inference)
3336
pass_library(fc_lstm_fuse_pass inference)
@@ -42,3 +45,6 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r
4245
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
4346
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
4447
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
48+
if(WITH_MKLDNN)
49+
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
50+
endif()
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {

// Scans `graph` for conv2d -> relu chains and collapses each one into a
// single conv2d op carrying fuse_relu=true, so the MKL-DNN conv kernel
// applies the activation in place. Returns the mutated graph.
std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get());

  GraphPatternDetector gpd;
  auto* conv_input = gpd.mutable_pattern()
                         ->NewNode("conv_relu_mkldnn_fuse/conv_input")
                         ->AsInput()
                         ->assert_is_op_input("conv2d", "Input");
  patterns::ConvReLU conv_relu_pattern(gpd.mutable_pattern(),
                                       "conv_relu_mkldnn_fuse");
  conv_relu_pattern(conv_input);

  int found_conv_relu_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvReLU fuse";
    // Validate the match before any subgraph.at() access (the original
    // code only checked after the first dereference).
    PADDLE_ENFORCE(subgraph.count(conv_input));
    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
                              conv_relu_pattern);  // Filter
    GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_relu_pattern);  // Bias
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern);    // tmp
    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern);        // CONV op
    GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern);    // Out
    GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern);        // ReLU op

    // Build the fused op: same inputs as the matched conv2d, output wired
    // straight to the relu output, skipping the intermediate tensor.
    OpDesc desc;
    std::string conv_relu_i_in = subgraph.at(conv_input)->Name();
    std::string conv_relu_w_in = conv_weight->Name();
    std::string conv_relu_b_in = conv_bias->Name();
    std::string conv_relu_out = relu_out->Name();
    desc.SetInput("Input", std::vector<std::string>({conv_relu_i_in}));
    desc.SetInput("Filter", std::vector<std::string>({conv_relu_w_in}));
    desc.SetInput("Bias", std::vector<std::string>({conv_relu_b_in}));
    desc.SetOutput("Out", std::vector<std::string>({conv_relu_out}));
    desc.SetType("conv2d");
    // Preserve every attribute of the original conv2d (strides, paddings,
    // use_mkldnn, ...) and then flag the fused activation.
    for (auto& attr : conv->Op()->GetAttrMap()) {
      desc.SetAttr(attr.first, attr.second);
    }
    desc.SetAttr("fuse_relu", true);
    auto conv_relu_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
    // conv, relu and the intermediate conv_out tensor are now dead.
    GraphSafeRemoveNodes(graph.get(), {conv, relu, conv_out});

    IR_NODE_LINK_TO(subgraph.at(conv_input), conv_relu_node);
    IR_NODE_LINK_TO(conv_weight, conv_relu_node);
    IR_NODE_LINK_TO(conv_bias, conv_relu_node);
    IR_NODE_LINK_TO(conv_relu_node, relu_out);

    found_conv_relu_count++;
  };

  gpd(graph.get(), handler);

  AddStatis(found_conv_relu_count);
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(conv_relu_mkldnn_fuse_pass,
              paddle::framework::ir::ConvReLUFusePass);
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

/*
 * Fuse the CONV and ReLU to a ConvReLUOp.
 *
 * Replaces every conv2d -> relu chain with a single conv2d carrying the
 * fuse_relu attribute, which the MKL-DNN kernel honors.
 */
class ConvReLUFusePass : public FusePassBase {
 public:
  virtual ~ConvReLUFusePass() {}

 protected:
  // Marked `override` so the compiler verifies the signature matches the
  // virtual declared in the Pass hierarchy.
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"

#include <gtest/gtest.h>

namespace paddle {
namespace framework {
namespace ir {

// Appends one op named `type` to block 0 of `prog`, wiring its inputs the
// way that op expects: conv2d takes (Input, Filter, Bias); relu takes X.
void SetOp(ProgramDesc* prog, const std::string& type,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs) {
  auto* op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);
  if (type == "conv2d") {
    op->SetAttr("use_mkldnn", true);
    op->SetInput("Input", {inputs[0]});
    op->SetInput("Filter", {inputs[1]});
    op->SetInput("Bias", {inputs[2]});
  } else if (type == "relu") {
    op->SetInput("X", inputs);
  }
  op->SetOutput("Out", outputs);
}

// Builds a toy program:
//   a->OP0->b
//   b->OP1->c
//   (c, weights, bias)->conv->f
//   (f)->relu->g
// Only `weights` and `bias` are persistable, matching what the pattern
// asserts for the conv2d Filter/Bias inputs.
ProgramDesc BuildProgramDesc() {
  ProgramDesc prog;
  const std::vector<std::string> names = {"a",    "b", "c", "weights",
                                          "bias", "f", "g"};
  for (const auto& name : names) {
    auto* var = prog.MutableBlock(0)->Var(name);
    var->SetType(proto::VarType::SELECTED_ROWS);
    if (name == "weights" || name == "bias") {
      var->SetPersistable(true);
    }
  }

  SetOp(&prog, "OP0", {"a"}, {"b"});
  SetOp(&prog, "OP1", {"b"}, {"c"});
  SetOp(&prog, "conv2d", {"c", "weights", "bias"}, {"f"});
  SetOp(&prog, "relu", {"f"}, {"g"});

  return prog;
}

TEST(ConvReLUFusePass, basic) {
  auto prog = BuildProgramDesc();

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  auto pass = PassRegistry::Instance().Get("conv_relu_mkldnn_fuse_pass");

  int original_nodes_num = graph->Nodes().size();

  graph = pass->Apply(std::move(graph));

  int current_nodes_num = graph->Nodes().size();

  // The fuse removes three nodes (conv, relu, conv_out) and creates one
  // (the fused conv2d) — a net loss of exactly two nodes.
  EXPECT_EQ(original_nodes_num - 2, current_nodes_num);

  // Exactly one conv2d op should now carry use_mkldnn && fuse_relu.
  int conv_relu_count = 0;
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp() || node->Op()->Type() != "conv2d") {
      continue;
    }
    auto* op = node->Op();
    const bool use_mkldnn = op->HasAttr("use_mkldnn") &&
                            boost::get<bool>(op->GetAttr("use_mkldnn"));
    const bool fuse_relu = op->HasAttr("fuse_relu") &&
                           boost::get<bool>(op->GetAttr("fuse_relu"));
    if (use_mkldnn && fuse_relu) {
      ++conv_relu_count;
    }
  }
  EXPECT_EQ(conv_relu_count, 1);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(conv_relu_mkldnn_fuse_pass);

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
5151
if (with_fc_bias) {
5252
// Add FC-bias with LSTM-bias and create a new weight
5353
PADDLE_ENFORCE(scope);
54-
const std::string& new_bias_var = name_scope + "_bias.new";
54+
const std::string& new_bias_var = patterns::UniqueKey("NewBias");
5555
auto* bias_var = scope->Var(new_bias_var);
5656
PADDLE_ENFORCE(bias_var);
5757
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
@@ -120,7 +120,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
120120

121121
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
122122
Graph* g) {
123-
124123
GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
125124
GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
126125
GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
@@ -136,7 +135,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
136135
fc_bias);
137136
// Remove unneeded nodes.
138137
std::unordered_set<const Node*> marked_nodes(
139-
{mul, lstm, elementwise_add});
138+
{mul, lstm, elementwise_add, fc_bias});
140139
GraphSafeRemoveNodes(graph, marked_nodes);
141140
} else {
142141
GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,39 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) {
522522
return false;
523523
}
524524

525+
PDNode* patterns::ConvReLU::operator()(
    paddle::framework::ir::PDNode* conv_input) {
  // Operator nodes of the pattern.
  conv_input->assert_is_op_input("conv2d", "Input");
  auto* conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
  auto* relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu");

  // Persistable conv2d inputs: Filter and Bias.
  auto* filter_var = pattern->NewNode(conv_weight_repr())
                         ->AsInput()
                         ->assert_is_persistable_var()
                         ->assert_is_op_input("conv2d", "Filter");
  auto* bias_var = pattern->NewNode(conv_bias_repr())
                       ->AsInput()
                       ->assert_is_persistable_var()
                       ->assert_is_op_input("conv2d", "Bias");

  // conv2d output feeding relu; removed from the IR once the fuse happens.
  auto* conv_out_var = pattern->NewNode(conv_out_repr())
                           ->AsIntermediate()
                           ->assert_is_only_output_of_op("conv2d")
                           ->assert_is_op_input("relu");

  // Final output of the whole chain.
  auto* relu_out_var = pattern->NewNode(relu_out_repr())
                           ->AsOutput()
                           ->assert_is_op_output("relu");

  conv_op->LinksFrom({conv_input, filter_var, bias_var})
      .LinksTo({conv_out_var});
  relu_op->LinksFrom({conv_out_var}).LinksTo({relu_out_var});
  return relu_out_var;
}
557+
525558
PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x,
526559
bool with_bias) {
527560
// Create shared nodes.

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,28 @@ struct PatternBase {
360360
size_t id_;
361361
};
362362

363+
// CONV with ReLU
364+
// op: conv + relu
365+
// named nodes:
366+
// conv_input, conv_weight,
367+
// conv_bias, conv_out, conv,
368+
// relu_out, relu
369+
struct ConvReLU : public PatternBase {
370+
ConvReLU(PDPattern* pattern, const std::string& name_scope)
371+
: PatternBase(pattern, name_scope, "conv_relu") {}
372+
373+
PDNode* operator()(PDNode* conv_input);
374+
375+
// declare operator node's name
376+
PATTERN_DECL_NODE(conv);
377+
PATTERN_DECL_NODE(relu);
378+
// declare variable node's name
379+
PATTERN_DECL_NODE(conv_weight);
380+
PATTERN_DECL_NODE(conv_bias);
381+
PATTERN_DECL_NODE(conv_out);
382+
PATTERN_DECL_NODE(relu_out);
383+
};
384+
363385
// FC with bias
364386
// op: mul + elementwise_add
365387
// named nodes:

0 commit comments

Comments
 (0)