Commit 66bb734

Merge pull request #12101 from panyx0718/ir
Initial IR change
2 parents 3694fd5 + 950585f


50 files changed (+1049 / -388 lines)

doc/fluid/design/ir/draft.md

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
## Motivation

There is a ```gap``` between the ```Program``` defined by the
user and the ```Executable``` that can be scheduled
efficiently on heterogeneous hardware, either locally
or in a distributed setting.

Usually, the ```gap``` is bridged by

* A series of transformations with a defined order.

* These transformations usually involve
```insert, delete, clustering, split, dependency analysis```.

* A simple way to verify and debug each transformation.

* Flexibility to add, remove, or customize transformations to fit
the requirements of various algorithms (models) and hardware scenarios.
Some other trends also push us toward a better unified pattern.

* Deep learning frameworks are built around the concept of graphs.
To leverage tools such as compilation (e.g. TVM and nGraph) or
cross-framework conversion (e.g. ONNX), we also need an intermediate
representation that can be connected to the rest of the ecosystem.

We need a unified pattern that naturally supports the requirements
described above. The pattern should fit training, inference,
and other offline serialized model transformations.
Learning from LLVM and other deep learning frameworks, we draft the
design below.
## Design

### Major Concepts

#### Node

```Node``` represents an operation that performs some computation, or
a variable that is an input or output of an operation.

```Node```s are connected to other ```Node```s via inputs and outputs.

Other properties (e.g. device placement information) can be added
to ```Node``` in the future if they become a
common requirement of many other ```Pass```es. Otherwise, they should live
in a ```Node``` wrapper class that is private to some ```Pass```, or be
a local member of a ```Pass```.
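To make the concept concrete, below is a minimal sketch of what a ```Node``` might look like. The ```(name, type)``` constructor and the ```kOperation```/```kVariable``` node types mirror the ```ir::Node("node0", ir::Node::Type::kOperation)``` calls in this commit's tests; the member and accessor names are illustrative assumptions, not the committed API.

```cpp
// Minimal Node sketch. Only the (name, type) constructor follows the
// ir::Node usage in this commit's tests; the rest is an assumption.
#include <string>
#include <vector>

namespace ir {

class Node {
 public:
  // A Node is either an operation or a variable.
  enum class Type { kOperation, kVariable };

  Node(const std::string& name, Type type) : name_(name), type_(type) {}

  const std::string& Name() const { return name_; }
  Type NodeType() const { return type_; }

  // Nodes are connected to other Nodes via inputs and outputs.
  std::vector<Node*> inputs;
  std::vector<Node*> outputs;

 private:
  std::string name_;
  Type type_;
};

}  // namespace ir
```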
#### Graph

```Graph``` contains a list of ```Node```s, which are connected to
each other via inputs and outputs.

TODO: Better definitions for the graph.

```Graph``` can also contain ```Attribute```s. ```Attribute```s
can be ```any```thing. For example, an ```Attribute``` can be a list
of "wrapper" nodes. The ```wrapper``` nodes compose ```Node```s and
provide helper methods for execution or transformation. ```Attribute```s
can also contain other things that describe some properties of
the ```Graph``` or ```Graph``` nodes. ```Attribute```s can be passed
across ```Pass```es. However, they should be used with care.
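A minimal ```Graph``` sketch under the same assumptions, owning its ```Node```s and carrying a type-erased ```Attribute``` map. The attribute interface shown here is purely illustrative; the committed API may differ.

```cpp
// Minimal Graph sketch, assuming the Node sketch above. The
// type-erased attribute map is an illustrative assumption.
#include <map>
#include <memory>
#include <string>
#include <vector>

namespace ir {

class Node { /* as sketched above */ };

class Graph {
 public:
  // The Graph owns its Nodes; Passes manipulate raw pointers.
  Node* AddNode(std::unique_ptr<Node> node) {
    nodes_.push_back(std::move(node));
    return nodes_.back().get();
  }

  const std::vector<std::unique_ptr<Node>>& Nodes() const { return nodes_; }

  // Attributes can be "anything": the Pass that knows the concrete
  // type stores and recovers them by name (non-owning, for brevity).
  template <typename AttrType>
  void SetAttr(const std::string& name, AttrType* attr) {
    attrs_[name] = attr;
  }
  template <typename AttrType>
  AttrType& GetAttr(const std::string& name) const {
    return *static_cast<AttrType*>(attrs_.at(name));
  }

 private:
  std::vector<std::unique_ptr<Node>> nodes_;
  std::map<std::string, void*> attrs_;
};

}  // namespace ir
```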
#### Pass

```Pass``` represents a transformation of a ```Graph```. Its input
is a ```Graph``` and its output is also a ```Graph```. For example,
a ```Pass``` can simply print out the ```Graph```. A ```Pass```
can also fuse some of a ```Graph```'s ```Node```s.
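Here is a sketch of the ```Pass``` contract described above, together with the "simply print out the ```Graph```" example. The ```Apply``` name and the ownership-passing signature are assumptions for illustration.

```cpp
// Pass sketch: input is a Graph, output is also a Graph. The
// GraphPrintPass is the "print the Graph" example from the text;
// names and signatures are assumptions.
#include <cstddef>
#include <iostream>
#include <memory>

namespace ir {

class Graph {
 public:
  std::size_t NumNodes() const { return num_nodes_; }
  std::size_t num_nodes_ = 0;  // stand-in for the node list sketched above
};

class Pass {
 public:
  virtual ~Pass() = default;
  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) = 0;
};

// The simplest Pass: inspect the Graph and return it unchanged.
class GraphPrintPass : public Pass {
 public:
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) override {
    std::cout << "graph with " << graph->NumNodes() << " nodes\n";
    return graph;
  }
};

}  // namespace ir
```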
#### Optimize

```Optimize``` contains a series of ```Pass```es with a defined order.
```Optimize``` transforms a ```Graph``` that only contains raw
modeling logic into a ```Graph``` that can be run efficiently while
maintaining the original modeling logic.
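One possible shape for ```Optimize```, again only a sketch: an ordered list of ```Pass```es applied one after another, each consuming and producing a ```Graph```. The ```Register```/```Apply``` names are assumed for illustration.

```cpp
// Optimize sketch: a series of Passes with a defined order.
// Register/Apply are assumed names; the fixed order is the point.
#include <memory>
#include <vector>

namespace ir {

class Graph {};  // stand-in for the Graph sketched above

class Pass {
 public:
  virtual ~Pass() = default;
  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> g) = 0;
};

class Optimize {
 public:
  // Passes run in exactly the order they are registered.
  void Register(std::unique_ptr<Pass> pass) {
    passes_.push_back(std::move(pass));
  }

  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) {
    for (auto& pass : passes_) {
      graph = pass->Apply(std::move(graph));
    }
    return graph;
  }

 private:
  std::vector<std::unique_ptr<Pass>> passes_;
};

}  // namespace ir
```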
### Optimize Process

* Program is first converted to Graph.
* Graph goes through a series of ```Pass```es (see the toy sketch below).
* Graph is transformed from raw model logic to a
form that is efficient to execute.

Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
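A self-contained toy version of that flow follows; ```ProgramToGraph``` and ```CountingPass``` are stubs standing in for the real conversion step and the real ```Pass```es, not the Fluid API.

```cpp
// Runnable toy version of the process above. All names are stubs.
#include <iostream>
#include <memory>
#include <vector>

struct Program {};                     // raw modeling logic from the user
struct Graph { int pass_count = 0; };  // stand-in for the IR Graph

struct Pass {
  virtual ~Pass() = default;
  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> g) = 0;
};

// Stand-in for any concrete transformation (fusion, pruning, ...).
struct CountingPass : Pass {
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> g) override {
    ++g->pass_count;
    return g;
  }
};

// Stand-in for the Program -> Graph conversion step.
std::unique_ptr<Graph> ProgramToGraph(const Program&) {
  return std::make_unique<Graph>();
}

int main() {
  Program program;
  auto graph = ProgramToGraph(program);

  // Program -> Graph -> Pass1 -> Graph -> Pass2 -> Graph -> Pass3 -> Graph
  std::vector<std::unique_ptr<Pass>> passes;
  for (int i = 0; i < 3; ++i) passes.push_back(std::make_unique<CountingPass>());
  for (auto& pass : passes) graph = pass->Apply(std::move(graph));

  std::cout << "passes applied: " << graph->pass_count << "\n";  // prints 3
  // A real pipeline would now hand the final Graph to the Executor.
}
```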

paddle/fluid/framework/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
 add_subdirectory(details)
+add_subdirectory(ir)
 # ddim lib
 proto_library(framework_proto SRCS framework.proto)
 
@@ -93,7 +94,7 @@ else()
 endif()
 
 
-cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
+cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph)
 
 cc_library(prune SRCS prune.cc DEPS framework_proto)
 cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 2 additions & 3 deletions
@@ -5,8 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
 
-cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
-cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
+cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph)
 cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
 cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
 
@@ -35,7 +34,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS
 
 cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
 
-cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
+cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
         simple_threadpool device_context)
 

paddle/fluid/framework/details/all_reduce_op_handle.cc

Lines changed: 9 additions & 4 deletions
@@ -23,20 +23,25 @@ namespace framework {
 namespace details {
 
 #ifdef PADDLE_WITH_CUDA
-AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
+                                     const std::vector<Scope *> &local_scopes,
                                      const std::vector<platform::Place> &places,
                                      const platform::NCCLContextMap *ctxs)
-    : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
+    : OpHandleBase(node),
+      local_scopes_(local_scopes),
+      places_(places),
+      nccl_ctxs_(ctxs) {
   if (nccl_ctxs_) {
     for (auto &p : places_) {
       this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
     }
   }
 }
 #else
-AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
+                                     const std::vector<Scope *> &local_scopes,
                                      const std::vector<platform::Place> &places)
-    : local_scopes_(local_scopes), places_(places) {}
+    : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
 #endif
 
 void AllReduceOpHandle::RunImpl() {

paddle/fluid/framework/details/all_reduce_op_handle.h

Lines changed: 2 additions & 2 deletions
@@ -30,11 +30,11 @@ namespace details {
 
 struct AllReduceOpHandle : public OpHandleBase {
 #ifdef PADDLE_WITH_CUDA
-  AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+  AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places,
                     const platform::NCCLContextMap *ctxs);
 #else
-  AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+  AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places);
 #endif
   std::string Name() const override;

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 7 additions & 4 deletions
@@ -35,20 +35,23 @@ namespace details {
 struct BroadcastOpHandle : public OpHandleBase {
  public:
 #ifdef PADDLE_WITH_CUDA
-  BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
+  BroadcastOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places,
                     const platform::NCCLContextMap *nccl_ctxs)
-      : local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) {
+      : OpHandleBase(node),
+        local_scopes_(local_scopes),
+        places_(places),
+        nccl_ctxs_(nccl_ctxs) {
     if (nccl_ctxs_) {
       for (auto &p_ctx : nccl_ctxs_->contexts_) {
         dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
       }
     }
   }
 #else
-  BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
+  BroadcastOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places)
-      : local_scopes_(local_scopes), places_(places) {}
+      : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
 #endif
 
   std::string Name() const override;

paddle/fluid/framework/details/broadcast_op_handle_test.cc

Lines changed: 25 additions & 12 deletions
@@ -96,48 +96,61 @@ struct TestBroadcastOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("input");
 
+    std::unique_ptr<ir::Node> n(
+        new ir::Node("node0", ir::Node::Type::kOperation));
     if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
-      op_handle_.reset(
-          new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
+      op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
+                                             nccl_ctxs_.get()));
 #else
       PADDLE_THROW("CUDA is not support.");
 #endif
     } else {
 #ifdef PADDLE_WITH_CUDA
-      op_handle_.reset(
-          new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get()));
+      op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
+                                             nccl_ctxs_.get()));
 #else
-      op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
+      op_handle_.reset(
+          new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_));
 #endif
     }
 
-    auto* in_var_handle =
-        new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
+    std::unique_ptr<ir::Node> v(
+        new ir::Node("node1", ir::Node::Type::kVariable));
+    auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
+                                        gpu_list_[input_scope_idx]);
     vars_.emplace_back(in_var_handle);
     op_handle_->AddInput(in_var_handle);
 
     // add dummy var
-    vars_.emplace_back(new DummyVarHandle());
+
+    std::unique_ptr<ir::Node> v2(
+        new ir::Node("node2", ir::Node::Type::kVariable));
+    vars_.emplace_back(new DummyVarHandle(v2.get()));
     DummyVarHandle* dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
-    dummy_var_handle->generated_op_ = nullptr;
+    dummy_var_handle->ClearGeneratedOp();
     op_handle_->AddInput(dummy_var_handle);
 
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       if (!use_gpu_) {
         op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
       }
-      VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
+      std::unique_ptr<ir::Node> v3(
+          new ir::Node("node3", ir::Node::Type::kVariable));
+      VarHandle* out_var_handle =
+          new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
       vars_.emplace_back(out_var_handle);
       op_handle_->AddOutput(out_var_handle);
     }
 
     // add dummy var
-    vars_.emplace_back(new DummyVarHandle());
+    std::unique_ptr<ir::Node> v4(
+        new ir::Node("node4", ir::Node::Type::kVariable));
+    vars_.emplace_back(new DummyVarHandle(v4.get()));
     DummyVarHandle* out_dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
-    out_dummy_var_handle->generated_op_ = nullptr;
+    out_dummy_var_handle->ClearGeneratedOp();
     op_handle_->AddOutput(out_dummy_var_handle);
   }

paddle/fluid/framework/details/computation_op_handle.cc

Lines changed: 5 additions & 4 deletions
@@ -19,9 +19,10 @@
 namespace paddle {
 namespace framework {
 namespace details {
-ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
+ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
                                          platform::Place place)
-    : op_(framework::OpRegistry::CreateOp(op_desc)),
+    : OpHandleBase(node),
+      op_(framework::OpRegistry::CreateOp(*node->Op())),
       scope_(scope),
       place_(place) {}
 
@@ -35,8 +36,8 @@ void ComputationOpHandle::RunImpl() {
 
 bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
   bool need_wait =
-      in_var && in_var->generated_op_ &&
-      in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_];
+      in_var && in_var->GeneratedOp() &&
+      in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_[place_];
   return need_wait;
 }

paddle/fluid/framework/details/computation_op_handle.h

Lines changed: 1 addition & 2 deletions
@@ -28,8 +28,7 @@ namespace framework {
 namespace details {
 struct ComputationOpHandle : public OpHandleBase {
  public:
-  ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
-                      platform::Place place);
+  ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place);
 
   std::string Name() const override;

paddle/fluid/framework/details/data_balance_op_handle.cc

Lines changed: 4 additions & 4 deletions
@@ -22,10 +22,10 @@ namespace details {
 
 #ifdef PADDLE_WITH_CUDA
 DataBalanceOpHandle::DataBalanceOpHandle(
-    const std::vector<Scope *> &local_scopes,
+    ir::Node *node, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
     const platform::NCCLContextMap *ctxs)
-    : local_scopes_(local_scopes), places_(places) {
+    : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {
   if (ctxs) {
     for (auto &p : places_) {
       this->dev_ctxes_[p] = ctxs->DevCtx(p);
@@ -34,9 +34,9 @@ DataBalanceOpHandle::DataBalanceOpHandle(
 }
 #else
 DataBalanceOpHandle::DataBalanceOpHandle(
-    const std::vector<Scope *> &local_scopes,
+    ir::Node *node, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places)
-    : local_scopes_(local_scopes), places_(places) {}
+    : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
 #endif
 
 std::string DataBalanceOpHandle::Name() const { return "data balance"; }
