Skip to content

Commit a5c96af

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_tensorrt_conv2d_converter
2 parents f05c7fb + baff71d commit a5c96af

File tree

75 files changed

+1849
-690
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+1849
-690
lines changed

doc/fluid/design/ir/draft.md

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,96 @@ can also contain other things that describe some properties of
6464
the `Graph` or `Graph` nodes. `Attribute` can be passed
6565
across `Pass`. However, it should be used with care.
6666

67+
```cpp
68+
class Graph {
69+
public:
70+
explicit Graph(const ProgramDesc &program);
71+
72+
bool Has(const std::string &attr_name) const;
73+
74+
template <typename AttrType>
75+
AttrType &Get(const std::string &attr_name) const;
76+
77+
template <typename AttrType>
78+
void Set(const std::string &attr_name, AttrType *attr);
79+
const std::unordered_set<ir::Node *> &Nodes() const;
80+
81+
// Create a normal variable with non-null VarDesc.
82+
ir::Node *CreateVarNode(VarDesc *var_desc);
83+
84+
// Create a normal runnable operator with OpDesc.
85+
ir::Node *CreateOpNode(OpDesc *op_desc);
86+
87+
// Create a control dependency var that connects 2 operations. The
88+
// var doesn't hold any data. Other than that, it's no different from
89+
// other var, considering dependency analysis.
90+
ir::Node *CreateControlDepVar();
91+
92+
// A more free-style way of creating a graph node. Mostly used for tests
93+
// or "copy" from another node. Avoid using it if possible.
94+
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type);
95+
96+
// Clear all node information of the graph and return the ownership of the
97+
// nodes.
98+
std::vector<std::unique_ptr<ir::Node>> ReleaseNodes();
99+
};
100+
```
101+
67102
#### Pass
68103
69104
`Pass` represents a transformation of `Graph`. Its input
70105
is a `Graph` and its output is also a `Graph`. For example,
71106
a `Pass` can simply print out the `Graph`. A `Pass`
72107
can also fuse some `Graph`'s `Node`s.
73108
109+
```cpp
110+
class Pass {
111+
public:
112+
113+
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const {
114+
// Some correctness check.
115+
auto new_graph = ApplyImpl(std::move(graph));
116+
// Some correctness check.
117+
return new_graph;
118+
}
119+
120+
// Get a reference to the attribute previously set.
121+
template <typename AttrType>
122+
AttrType &Get(const std::string &attr_name) const;
123+
124+
// Set a pointer to the attribute. Pass takes ownership of the attribute.
125+
template <typename AttrType>
126+
void Set(const std::string &attr_name, AttrType *attr) ;
127+
128+
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
129+
// should delete the attribute.
130+
template <typename AttrType>
131+
void SetNotOwned(const std::string &attr_name, AttrType *attr);
132+
133+
protected:
134+
virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const = 0;
135+
};
136+
137+
// In my_pass.cc
138+
class MyPass : public Pass {
139+
protected:
140+
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override {
141+
// do something.
142+
return graph;
143+
}
144+
};
145+
REGISTER_PASS(my_pass, MyPass)
146+
.RequirePassAttr("places")
147+
.RequireGraphAttr("dep_vars");
148+
149+
150+
// To use the pass.
151+
auto my_pass = ir::PassRegistry::Instance().Get("my_pass");
152+
graph = my_pass->Apply(std::move(graph));
153+
// Note: to force link my_pass.cc, in the code:
154+
USE_PASS(my_pass);
155+
```
156+
74157
#### Optimize
75158

76159
`Optimize` contains a series of `Pass` with defined order.
@@ -86,4 +169,17 @@ maintaining the original modeling logic.
86169
* Graph is transformed from raw model logic to a
87170
form that is efficient to execute.
88171

89-
Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
172+
```
173+
// Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
174+
auto graph = Graph(program);
175+
graph = PassRegistry::Instance().Get("op_fuse_pass")->Apply(std::move(graph));
176+
// For more complex Pass, Optimize Process can provide Pass attributes.
177+
auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass");
178+
mem_opt_pass->SetNotOwned<int>("optimize_level", 1);
179+
mem_opt_pass->Apply(std::move(graph));
180+
graph = PassRegistry::Instance().Get("multi_device_pass")->Apply(std::move(graph));
181+
graph = PassRegistry::Instance().Get("multi_device_check_pass")->Apply(std::move(graph));
182+
Executor exe;
183+
exe.Run(graph);
184+
185+
```

paddle/fluid/API.spec

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], var
170170
paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
171171
paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
172172
paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
173+
paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
173174
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
174175
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
175176
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
@@ -201,7 +202,6 @@ paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=
201202
paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None)
202203
paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'name'], varargs=None, keywords=None, defaults=(None,))
203204
paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
204-
paddle.fluid.layers.While.complete ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
205205
paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,))
206206
paddle.fluid.layers.Switch.case ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None)
207207
paddle.fluid.layers.Switch.default ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
@@ -225,17 +225,14 @@ paddle.fluid.layers.DynamicRNN.static_input ArgSpec(args=['self', 'x'], varargs=
225225
paddle.fluid.layers.DynamicRNN.step_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None)
226226
paddle.fluid.layers.DynamicRNN.update_memory ArgSpec(args=['self', 'ex_mem', 'new_mem'], varargs=None, keywords=None, defaults=None)
227227
paddle.fluid.layers.StaticRNN.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,))
228-
paddle.fluid.layers.StaticRNN.complete_op ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
229228
paddle.fluid.layers.StaticRNN.memory ArgSpec(args=['self', 'init', 'shape', 'batch_ref', 'init_value', 'init_batch_dim_idx', 'ref_batch_dim_idx'], varargs=None, keywords=None, defaults=(None, None, None, 0.0, 0, 1))
230229
paddle.fluid.layers.StaticRNN.output ArgSpec(args=['self'], varargs='outputs', keywords=None, defaults=None)
231-
paddle.fluid.layers.StaticRNN.parent_block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
232230
paddle.fluid.layers.StaticRNN.step ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
233231
paddle.fluid.layers.StaticRNN.step_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None)
234232
paddle.fluid.layers.StaticRNN.step_output ArgSpec(args=['self', 'o'], varargs=None, keywords=None, defaults=None)
235233
paddle.fluid.layers.StaticRNN.update_memory ArgSpec(args=['self', 'mem', 'var'], varargs=None, keywords=None, defaults=None)
236234
paddle.fluid.layers.reorder_lod_tensor_by_rank ArgSpec(args=['x', 'rank_table'], varargs=None, keywords=None, defaults=None)
237235
paddle.fluid.layers.ParallelDo.__init__ ArgSpec(args=['self', 'places', 'use_nccl', 'name'], varargs=None, keywords=None, defaults=(False, None))
238-
paddle.fluid.layers.ParallelDo.complete_op ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
239236
paddle.fluid.layers.ParallelDo.do ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
240237
paddle.fluid.layers.ParallelDo.get_parameters ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
241238
paddle.fluid.layers.ParallelDo.parent_block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ else()
9999
endif()
100100

101101

102-
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph)
102+
cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
103103

104104
cc_library(prune SRCS prune.cc DEPS framework_proto)
105105
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
3131
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
3232
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
3333

34-
35-
cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
36-
3734
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
3835
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
3936
simple_threadpool device_context)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <memory>
#include <mutex>

#include "paddle/fluid/platform/enforce.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace details {
22+
23+
class ExceptionHolder {
24+
public:
25+
void Catch(const platform::EnforceNotMet& exp) {
26+
std::lock_guard<std::mutex> lock(mu_);
27+
exception_.reset(new platform::EnforceNotMet(exp));
28+
type_ = kEnforceNotMet;
29+
}
30+
31+
void Catch(const platform::EOFException& exp) {
32+
std::lock_guard<std::mutex> lock(mu_);
33+
// EOFException will not cover up existing EnforceNotMet.
34+
if (exception_.get() == nullptr) {
35+
exception_.reset(new platform::EOFException(exp));
36+
type_ = kEOF;
37+
}
38+
}
39+
40+
bool ExceptionCatched() const {
41+
std::lock_guard<std::mutex> lock(mu_);
42+
return exception_.get() != nullptr;
43+
}
44+
45+
void Throw() {
46+
std::lock_guard<std::mutex> lock(mu_);
47+
switch (type_) {
48+
case kNone:
49+
break;
50+
case kEnforceNotMet: {
51+
auto e = *static_cast<platform::EnforceNotMet*>(exception_.get());
52+
throw e;
53+
break;
54+
}
55+
case kEOF: {
56+
auto e = *static_cast<platform::EOFException*>(exception_.get());
57+
throw e;
58+
break;
59+
}
60+
default:
61+
LOG(FATAL) << "Unknown exception.";
62+
}
63+
exception_.reset();
64+
type_ = kNone;
65+
}
66+
67+
void Clear() {
68+
std::lock_guard<std::mutex> lock(mu_);
69+
exception_.reset();
70+
type_ = kNone;
71+
}
72+
73+
private:
74+
enum ExceptionType { kNone, kEnforceNotMet, kEOF };
75+
ExceptionType type_{kNone};
76+
77+
std::unique_ptr<std::exception> exception_;
78+
mutable std::mutex mu_;
79+
};
80+
81+
} // namespace details
82+
} // namespace framework
83+
} // namespace paddle

0 commit comments

Comments
 (0)