Skip to content

Commit aa1085d

Browse files
committed
all passes
add doc
1 parent e4d7d7a commit aa1085d

File tree

10 files changed

+87
-154
lines changed

10 files changed

+87
-154
lines changed

doc/fluid/design/ir/draft.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,44 @@ is a `Graph` and its output is also a `Graph`. For example,
7171
a `Pass` can simply print out the `Graph`. A `Pass`
7272
can also fuse some `Graph`'s `Node`s.
7373

74+
```cpp
75+
class Pass {
76+
public:
77+
78+
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
79+
80+
// Get a reference to the attributed previously set.
81+
template <typename AttrType>
82+
AttrType &Get(const std::string &attr_name) const;
83+
84+
// Set a pointer to the attribute. Pass takes ownership of the attribute.
85+
template <typename AttrType>
86+
void Set(const std::string &attr_name, AttrType *attr) ;
87+
88+
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
89+
// should delete the attribute.
90+
template <typename AttrType>
91+
void SetNotOwned(const std::string &attr_name, AttrType *attr);
92+
};
93+
94+
// In my_pass.cc
95+
class MyPass : public Pass {
96+
public:
97+
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
98+
// do something.
99+
return graph;
100+
}
101+
}
102+
REGISTER_PASS(my_pass, MyPass);
103+
104+
105+
// To use the pass.
106+
auto my_pass = ir::PassRegistry::Instance().Get("my_pass");
107+
graph = my_pass->Apply(std::move(graph));
108+
// Note: to force link my_pass.cc, in the code:
109+
USE_PASS(my_pass);
110+
```
111+
74112
#### Optimize
75113
76114
`Optimize` contains a series of `Pass` with defined order.

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ else()
9999
endif()
100100

101101

102-
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass)
102+
cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
103103

104104
cc_library(prune SRCS prune.cc DEPS framework_proto)
105105
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
3131
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
3232
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
3333

34-
35-
cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
36-
3734
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
3835
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
3936
simple_threadpool device_context)

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ namespace framework {
3535
namespace details {
3636

3737
void MultiDevSSAGraphBuilder::Init() const {
38-
loss_var_name_ = Get<std::string>("loss_var_name");
39-
places_ = Get<std::vector<platform::Place>>("places");
40-
local_scopes_ = Get<std::vector<Scope *>>("local_scopes");
41-
strategy_ = Get<BuildStrategy>("strategy");
38+
loss_var_name_ = Get<const std::string>("loss_var_name");
39+
places_ = Get<const std::vector<platform::Place>>("places");
40+
local_scopes_ = Get<const std::vector<Scope *>>("local_scopes");
41+
strategy_ = Get<const BuildStrategy>("strategy");
4242
#ifdef PADDLE_WITH_CUDA
4343
nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
4444
#endif
4545

46-
for (auto &p : Get<std::unordered_set<std::string>>("params")) {
46+
for (auto &p : Get<const std::unordered_set<std::string>>("params")) {
4747
grad_names_.insert(GradVarName(p));
4848
}
4949
balance_vars_.resize(places_.size(), 0);

paddle/fluid/framework/details/ssa_graph_builder_factory.cc

Lines changed: 0 additions & 53 deletions
This file was deleted.

paddle/fluid/framework/details/ssa_graph_builder_factory.h

Lines changed: 0 additions & 71 deletions
This file was deleted.

paddle/fluid/framework/details/ssa_graph_checker.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,8 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
2626
public:
2727
std::unique_ptr<ir::Graph> Apply(
2828
std::unique_ptr<ir::Graph> graph) const override {
29-
auto new_graph = Get<ir::Pass>("previous_pass").Apply(std::move(graph));
30-
PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
31-
return new_graph;
29+
PADDLE_ENFORCE(IsValidGraph(graph.get()));
30+
return graph;
3231
}
3332

3433
bool IsValidGraph(const ir::Graph* graph) const;

paddle/fluid/framework/details/ssa_graph_printer.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,11 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
3939
public:
4040
std::unique_ptr<ir::Graph> Apply(
4141
std::unique_ptr<ir::Graph> graph) const override {
42-
auto new_graph = Get<ir::Pass>("previous_pass").Apply(std::move(graph));
43-
4442
std::unique_ptr<std::ostream> fout(
45-
new std::ofstream(Get<std::string>("debug_graphviz_path")));
43+
new std::ofstream(Get<const std::string>("debug_graphviz_path")));
4644
PADDLE_ENFORCE(fout->good());
47-
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*new_graph, *fout);
48-
return new_graph;
45+
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
46+
return graph;
4947
}
5048
};
5149

paddle/fluid/framework/ir/pass.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@ class Pass {
4242

4343
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
4444

45+
// Get a reference to the attributed previously set.
4546
template <typename AttrType>
4647
AttrType &Get(const std::string &attr_name) const {
4748
PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
4849
"%s attr not registered for pass.", attr_name);
4950
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
5051
}
5152

53+
// Set a pointer to the attribute. Pass takes ownership of the attribute.
5254
template <typename AttrType>
5355
void Set(const std::string &attr_name, AttrType *attr) {
5456
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
@@ -59,6 +61,8 @@ class Pass {
5961
};
6062
}
6163

64+
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
65+
// should delete the attribute.
6266
template <typename AttrType>
6367
void SetNotOwned(const std::string &attr_name, AttrType *attr) {
6468
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
@@ -127,6 +131,7 @@ struct PassRegistrar : public Registrar {
127131
__test_global_namespace_##uniq_name##__>::value, \
128132
msg)
129133

134+
// Register a new pass that can be applied on the IR.
130135
#define REGISTER_PASS(pass_type, pass_class) \
131136
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
132137
__reg_pass__##pass_type, \

paddle/fluid/framework/parallel_executor.cc

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ limitations under the License. */
2626
#endif
2727

2828
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
29-
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
29+
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
30+
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
3031
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
3132
#include "paddle/fluid/platform/profiler.h"
3233

@@ -43,16 +44,6 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
4344
#else
4445
const BuildStrategy &strategy) {
4546
#endif
46-
details::ParallelExecutorPassManager builder_factory(
47-
places, loss_var_name, param_names, local_scopes, strategy);
48-
if (use_cuda) {
49-
#ifdef PADDLE_WITH_CUDA
50-
builder_factory.SetNCCLContextMap(nccl_ctxs);
51-
#else
52-
PADDLE_THROW("Not compiled with CUDA.");
53-
#endif
54-
}
55-
5647
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
5748
if (!strategy.debug_graphviz_path_.empty()) {
5849
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
@@ -62,8 +53,37 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
6253
graph = viz_pass->Apply(std::move(graph));
6354
}
6455

65-
auto builder = builder_factory.Create();
66-
graph = builder->Apply(std::move(graph));
56+
auto multi_device_pass =
57+
ir::PassRegistry::Instance().Get("multi_device_pass");
58+
multi_device_pass->SetNotOwned<const std::vector<platform::Place>>("places",
59+
&places);
60+
multi_device_pass->SetNotOwned<const std::string>("loss_var_name",
61+
&loss_var_name);
62+
multi_device_pass->SetNotOwned<const std::unordered_set<std::string>>(
63+
"params", &param_names);
64+
multi_device_pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
65+
&local_scopes);
66+
multi_device_pass->SetNotOwned<const BuildStrategy>("strategy", &strategy);
67+
68+
#ifdef PADDLE_WITH_CUDA
69+
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
70+
multi_device_pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
71+
#endif
72+
graph = multi_device_pass->Apply(std::move(graph));
73+
74+
if (!strategy.debug_graphviz_path_.empty()) {
75+
auto multi_device_print_pass =
76+
ir::PassRegistry::Instance().Get("multi_device_print_pass");
77+
multi_device_print_pass->SetNotOwned<const std::string>(
78+
"debug_graphviz_path", &strategy.debug_graphviz_path_);
79+
multi_device_print_pass->Set<details::GraphvizSSAGraphPrinter>(
80+
"graph_printer", new details::GraphvizSSAGraphPrinter);
81+
graph = multi_device_print_pass->Apply(std::move(graph));
82+
}
83+
84+
auto multi_device_check_pass =
85+
ir::PassRegistry::Instance().Get("multi_device_check_pass");
86+
graph = multi_device_check_pass->Apply(std::move(graph));
6787

6888
if (!strategy.debug_graphviz_path_.empty()) {
6989
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");

0 commit comments

Comments
 (0)