Skip to content

Commit 99dffb9

Browse files
committed
allow to repeatedly share and update BuildStrategy
test=develop
1 parent df826de commit 99dffb9

File tree

4 files changed

+20
-11
lines changed

4 files changed

+20
-11
lines changed

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
7979
BuildStrategy strategy_;
8080
};
8181

82-
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy()
83-
const {
82+
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
83+
bool from_user) const {
84+
if (finalized_by_user_) {
85+
return pass_builder_;
86+
}
8487
pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
88+
if (from_user) {
89+
finalized_by_user_ = true;
90+
}
8591
return pass_builder_;
8692
}
8793

@@ -95,10 +101,8 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
95101
#else
96102
const bool use_cuda) const {
97103
#endif
98-
// Create a default one if not initialized by user.
99-
if (!pass_builder_) {
100-
CreatePassesFromStrategy();
101-
}
104+
// Create a default one if not finalized by user.
105+
CreatePassesFromStrategy(false);
102106

103107
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
104108

paddle/fluid/framework/details/build_strategy.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ struct BuildStrategy {
8080
// from python side.
8181
// A new PassBuilder is created based on configs defined above and
8282
// passes are owned by the PassBuilder.
83-
std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy() const;
83+
std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy(
84+
bool from_user) const;
8485

8586
// Apply the passes built by the pass_builder_. The passes will be
8687
// applied to the Program and output an ir::Graph.
@@ -97,6 +98,7 @@ struct BuildStrategy {
9798
#endif
9899

99100
private:
101+
mutable bool finalized_by_user_ = false;
100102
mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
101103
};
102104

paddle/fluid/pybind/pybind.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -855,10 +855,13 @@ All parameter, weight, gradient are variables in Paddle.
855855
R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
856856
to fuse elementwise_add_op and activation_op,
857857
it may make the execution faster. Default False)DOC")
858-
.def("_create_passes_from_strategy",
858+
.def("_finalize_strategy_and_create_passes",
859859
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
860-
return self.CreatePassesFromStrategy();
861-
});
860+
return self.CreatePassesFromStrategy(true);
861+
},
862+
R"DOC(Allow user to customized passes. Normally model-specific
863+
optimization passes should be defined in this way. BuildStrategy
864+
cannot be updated after being finalized.)DOC");
862865

863866
pe.def(py::init<const std::vector<platform::Place> &,
864867
const std::unordered_set<std::string> &,

python/paddle/fluid/tests/unittests/test_pass_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def check_network_convergence(self, use_cuda, build_strategy=None):
9494

9595
def test_parallel_testing_with_new_strategy(self):
9696
build_strategy = fluid.BuildStrategy()
97-
pass_builder = build_strategy._create_passes_from_strategy()
97+
pass_builder = build_strategy._finalize_strategy_and_create_passes()
9898
origin_len = len(pass_builder.all_passes())
9999

100100
viz_pass = pass_builder.append_pass("graph_viz_pass")

0 commit comments

Comments
 (0)