Skip to content

Commit 7cd2761

Browse files
authored
Merge pull request #13416 from panyx0718/ir
PassBuilder
2 parents 43a3af8 + cbdf983 commit 7cd2761

File tree

15 files changed

+454
-109
lines changed

15 files changed

+454
-109
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,10 @@ else()
150150
endif()
151151

152152
if (NOT WIN32)
153-
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
154-
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
155-
graph graph_viz_pass multi_devices_graph_pass
156-
multi_devices_graph_print_pass multi_devices_graph_check_pass
157-
fast_threaded_ssa_graph_executor fuse_elewise_add_act_pass)
153+
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
154+
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
155+
graph build_strategy
156+
fast_threaded_ssa_graph_executor)
158157
endif() # NOT WIN32
159158

160159
cc_library(prune SRCS prune.cc DEPS framework_proto)

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,8 @@ cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_execu
5454
# device_context reduce_op_handle )
5555
cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc
5656
DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
57+
58+
cc_library(build_strategy SRCS build_strategy.cc DEPS
59+
graph_viz_pass multi_devices_graph_pass
60+
multi_devices_graph_print_pass multi_devices_graph_check_pass
61+
fuse_elewise_add_act_pass)
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/details/build_strategy.h"
16+
17+
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
18+
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
19+
#include "paddle/fluid/framework/ir/graph.h"
20+
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
21+
22+
namespace paddle {
23+
namespace framework {
24+
namespace details {
25+
26+
class ParallelExecutorPassBuilder : public ir::PassBuilder {
27+
public:
28+
explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
29+
: ir::PassBuilder(), strategy_(strategy) {
30+
// Add a graph viz pass to record a graph.
31+
if (!strategy_.debug_graphviz_path_.empty()) {
32+
auto viz_pass = AppendPass("graph_viz_pass");
33+
const std::string graph_path = string::Sprintf(
34+
"%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
35+
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
36+
}
37+
38+
// Add op fusion.
39+
if (strategy.fuse_elewise_add_act_ops_) {
40+
auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass");
41+
// Add a graph viz pass to record a graph.
42+
if (!strategy.debug_graphviz_path_.empty()) {
43+
auto viz_pass = AppendPass("graph_viz_pass");
44+
const std::string graph_path = string::Sprintf(
45+
"%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
46+
viz_pass->Set<std::string>("graph_viz_path",
47+
new std::string(graph_path));
48+
}
49+
}
50+
51+
// Convert graph to run on multi-devices.
52+
auto multi_devices_pass = AppendPass("multi_devices_pass");
53+
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
54+
&strategy_);
55+
56+
// Add a graph print pass to record a graph with device info.
57+
if (!strategy_.debug_graphviz_path_.empty()) {
58+
auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
59+
multi_devices_print_pass->SetNotOwned<const std::string>(
60+
"debug_graphviz_path", &strategy_.debug_graphviz_path_);
61+
multi_devices_print_pass->Set<details::GraphvizSSAGraphPrinter>(
62+
"graph_printer", new details::GraphvizSSAGraphPrinter);
63+
}
64+
65+
// Verify that the graph is correct for multi-device executor.
66+
AppendPass("multi_devices_check_pass");
67+
}
68+
69+
private:
70+
BuildStrategy strategy_;
71+
};
72+
73+
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy()
74+
const {
75+
pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
76+
return pass_builder_;
77+
}
78+
79+
std::unique_ptr<ir::Graph> BuildStrategy::Apply(
80+
const ProgramDesc &main_program, const std::vector<platform::Place> &places,
81+
const std::string &loss_var_name,
82+
const std::unordered_set<std::string> &param_names,
83+
const std::vector<Scope *> &local_scopes,
84+
#ifdef PADDLE_WITH_CUDA
85+
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
86+
#else
87+
const bool use_cuda) const {
88+
#endif
89+
// Create a default one if not initialized by user.
90+
if (!pass_builder_) {
91+
CreatePassesFromStrategy();
92+
}
93+
94+
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
95+
96+
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
97+
if (pass->Type() == "multi_devices_pass") {
98+
pass->Erase("places");
99+
pass->SetNotOwned<const std::vector<platform::Place>>("places", &places);
100+
pass->Erase("loss_var_name");
101+
pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name);
102+
pass->Erase("params");
103+
pass->SetNotOwned<const std::unordered_set<std::string>>("params",
104+
&param_names);
105+
pass->Erase("local_scopes");
106+
pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
107+
&local_scopes);
108+
#ifdef PADDLE_WITH_CUDA
109+
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
110+
pass->Erase("nccl_ctxs");
111+
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
112+
#endif
113+
}
114+
graph = pass->Apply(std::move(graph));
115+
}
116+
return graph;
117+
}
118+
} // namespace details
119+
} // namespace framework
120+
} // namespace paddle
121+
122+
USE_PASS(fuse_elewise_add_act_pass);
123+
USE_PASS(graph_viz_pass);
124+
USE_PASS(multi_devices_pass);
125+
USE_PASS(multi_devices_check_pass);
126+
USE_PASS(multi_devices_print_pass);

paddle/fluid/framework/details/build_strategy.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@
1515
#pragma once
1616

1717
#include <string>
18+
#include <vector>
19+
20+
#include "paddle/fluid/framework/ir/pass_builder.h"
21+
#include "paddle/fluid/framework/program_desc.h"
22+
#include "paddle/fluid/framework/scope.h"
23+
#include "paddle/fluid/platform/device_context.h"
24+
#include "paddle/fluid/platform/enforce.h"
25+
26+
#ifdef PADDLE_WITH_CUDA
27+
#include "paddle/fluid/platform/nccl_helper.h"
28+
#endif
1829

1930
namespace paddle {
2031
namespace framework {
@@ -57,6 +68,30 @@ struct BuildStrategy {
5768
bool fuse_elewise_add_act_ops_{false};
5869

5970
bool enable_data_balance_{false};
71+
72+
// User normally doesn't need to call this API.
73+
// The PassBuilder allows for more customized insert, remove of passes
74+
// from python side.
75+
// A new PassBuilder is created based on configs defined above and
76+
// passes are owned by the PassBuilder.
77+
std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy() const;
78+
79+
// Apply the passes built by the pass_builder_. The passes will be
80+
// applied to the Program and output an ir::Graph.
81+
std::unique_ptr<ir::Graph> Apply(
82+
const ProgramDesc &main_program,
83+
const std::vector<platform::Place> &places,
84+
const std::string &loss_var_name,
85+
const std::unordered_set<std::string> &param_names,
86+
const std::vector<Scope *> &local_scopes,
87+
#ifdef PADDLE_WITH_CUDA
88+
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const;
89+
#else
90+
const bool use_cuda) const;
91+
#endif
92+
93+
private:
94+
mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
6095
};
6196

6297
} // namespace details

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass
4141

4242
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
4343

44+
cc_library(pass_builder SRCS pass_builder.cc DEPS pass)
45+
4446
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
4547
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
4648
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)

paddle/fluid/framework/ir/pass.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ namespace paddle {
1919
namespace framework {
2020
namespace ir {
2121
std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
22-
PADDLE_ENFORCE(!applied_, "Pass can only Apply() once.");
2322
PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty.");
2423
for (const std::string& attr : required_pass_attrs_) {
2524
PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),

paddle/fluid/framework/ir/pass.h

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class Pass {
4242
attr_dels_.clear();
4343
}
4444

45+
std::string Type() const { return type_; }
46+
4547
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const;
4648

4749
// Get a reference to the attributed previously set.
@@ -52,6 +54,21 @@ class Pass {
5254
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
5355
}
5456

57+
bool Has(const std::string &attr_name) const {
58+
return attrs_.find(attr_name) != attrs_.end();
59+
}
60+
61+
void Erase(const std::string &attr_name) {
62+
if (!Has(attr_name)) {
63+
return;
64+
}
65+
if (attr_dels_.find(attr_name) != attr_dels_.end()) {
66+
attr_dels_[attr_name]();
67+
attr_dels_.erase(attr_name);
68+
}
69+
attrs_.erase(attr_name);
70+
}
71+
5572
// Set a pointer to the attribute. Pass takes ownership of the attribute.
5673
template <typename AttrType>
5774
void Set(const std::string &attr_name, AttrType *attr) {
@@ -68,13 +85,15 @@ class Pass {
6885
// should delete the attribute.
6986
template <typename AttrType>
7087
void SetNotOwned(const std::string &attr_name, AttrType *attr) {
71-
PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
88+
PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass",
89+
attr_name);
7290
attrs_[attr_name] = attr;
7391
}
7492

7593
protected:
76-
virtual std::unique_ptr<Graph> ApplyImpl(
77-
std::unique_ptr<Graph> graph) const = 0;
94+
virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const {
95+
LOG(FATAL) << "Calling virtual Pass not implemented.";
96+
}
7897

7998
private:
8099
template <typename PassType>
@@ -89,7 +108,10 @@ class Pass {
89108
required_graph_attrs_.insert(attrs.begin(), attrs.end());
90109
}
91110

111+
void RegisterType(const std::string &type) { type_ = type; }
112+
92113
mutable bool applied_{false};
114+
std::string type_;
93115
std::unordered_set<std::string> required_pass_attrs_;
94116
std::unordered_set<std::string> required_graph_attrs_;
95117
std::map<std::string, boost::any> attrs_;
@@ -143,10 +165,11 @@ struct PassRegistrar : public Registrar {
143165
PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type),
144166
"'%s' is registered more than once.", pass_type);
145167
PassRegistry::Instance().Insert(
146-
pass_type, [this]() -> std::unique_ptr<Pass> {
168+
pass_type, [this, pass_type]() -> std::unique_ptr<Pass> {
147169
std::unique_ptr<Pass> pass(new PassType());
148170
pass->RegisterRequiredPassAttrs(this->required_pass_attrs_);
149171
pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_);
172+
pass->RegisterType(pass_type);
150173
return pass;
151174
});
152175
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/ir/pass_builder.h"
16+
17+
namespace paddle {
18+
namespace framework {
19+
namespace ir {
20+
21+
std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) {
22+
auto pass = ir::PassRegistry::Instance().Get(pass_type);
23+
passes_.emplace_back(pass.release());
24+
return passes_.back();
25+
}
26+
27+
void PassBuilder::RemovePass(size_t idx) {
28+
PADDLE_ENFORCE(passes_.size() > idx);
29+
passes_.erase(passes_.begin() + idx);
30+
}
31+
32+
std::shared_ptr<Pass> PassBuilder::InsertPass(size_t idx,
33+
const std::string& pass_type) {
34+
PADDLE_ENFORCE(passes_.size() >= idx);
35+
std::shared_ptr<Pass> pass(
36+
ir::PassRegistry::Instance().Get(pass_type).release());
37+
passes_.insert(passes_.begin() + idx, std::move(pass));
38+
return passes_[idx];
39+
}
40+
41+
} // namespace ir
42+
} // namespace framework
43+
} // namespace paddle
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include <vector>
19+
#include "paddle/fluid/framework/ir/pass.h"
20+
21+
namespace paddle {
22+
namespace framework {
23+
namespace ir {
24+
25+
class PassBuilder {
26+
public:
27+
PassBuilder() {}
28+
29+
virtual ~PassBuilder() {}
30+
31+
// Append a new pass to the end.
32+
std::shared_ptr<Pass> AppendPass(const std::string& pass_type);
33+
34+
// Insert a new pass after `idx`.
35+
std::shared_ptr<Pass> InsertPass(size_t idx, const std::string& pass_type);
36+
37+
// Remove a new pass at `idx`.
38+
void RemovePass(size_t idx);
39+
40+
// Returns a list of all passes.
41+
std::vector<std::shared_ptr<Pass>> AllPasses() const { return passes_; }
42+
43+
protected:
44+
std::vector<std::shared_ptr<Pass>> passes_;
45+
};
46+
47+
} // namespace ir
48+
} // namespace framework
49+
} // namespace paddle

0 commit comments

Comments
 (0)