Skip to content

Commit 8cfda7e

Browse files
authored
Merge pull request #14382 from panyx0718/fix4
Refine the pass builder and buildstrategy
2 parents 8f301f4 + bae3659 commit 8cfda7e

File tree

5 files changed

+43
-15
lines changed

5 files changed

+43
-15
lines changed

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
7979
BuildStrategy strategy_;
8080
};
8181

82-
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy()
83-
const {
82+
std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
83+
bool finalize_strategy) const {
84+
if (is_finalized_) {
85+
return pass_builder_;
86+
}
8487
pass_builder_.reset(new ParallelExecutorPassBuilder(*this));
88+
if (finalize_strategy) {
89+
is_finalized_ = true;
90+
}
8591
return pass_builder_;
8692
}
8793

@@ -95,10 +101,8 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
95101
#else
96102
const bool use_cuda) const {
97103
#endif
98-
// Create a default one if not initialized by user.
99-
if (!pass_builder_) {
100-
CreatePassesFromStrategy();
101-
}
104+
// Create a default one if not finalized by user.
105+
CreatePassesFromStrategy(false);
102106

103107
std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
104108

paddle/fluid/framework/details/build_strategy.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,20 @@ struct BuildStrategy {
7575

7676
bool remove_unnecessary_lock_{false};
7777

78+
// NOTE:
79+
// Before you add new options, think if it's a general strategy that works
80+
// with other strategy. If not, the strategy should be created through
81+
// CreatePassesFromStrategy and the pass can be managed separately.
82+
7883
// User normally doesn't need to call this API.
7984
// The PassBuilder allows for more customized insert, remove of passes
8085
// from python side.
8186
// A new PassBuilder is created based on configs defined above and
8287
// passes are owned by the PassBuilder.
83-
std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy() const;
88+
std::shared_ptr<ir::PassBuilder> CreatePassesFromStrategy(
89+
bool finalize_strategy) const;
90+
91+
bool IsFinalized() const { return is_finalized_; }
8492

8593
// Apply the passes built by the pass_builder_. The passes will be
8694
// applied to the Program and output an ir::Graph.
@@ -97,6 +105,7 @@ struct BuildStrategy {
97105
#endif
98106

99107
private:
108+
mutable bool is_finalized_ = false;
100109
mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
101110
};
102111

paddle/fluid/pybind/pybind.cc

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -650,9 +650,9 @@ All parameter, weight, gradient are variables in Paddle.
650650
[](ir::Pass &self, const std::string &name, const std::string &attr) {
651651
self.Set<std::string>(name, new std::string(attr));
652652
})
653-
.def("set_int", [](ir::Pass &self, const std::string &name, int val) {
654-
self.Set<const int>(name, new int(val));
655-
});
653+
.def("set_int", [](ir::Pass &self, const std::string &name,
654+
int val) { self.Set<const int>(name, new int(val)); })
655+
.def("type", &ir::Pass::Type);
656656

657657
py::class_<ir::PassBuilder, std::shared_ptr<ir::PassBuilder>> pb(
658658
m, "PassBuilder");
@@ -791,6 +791,7 @@ All parameter, weight, gradient are variables in Paddle.
791791
"reduce_strategy",
792792
[](const BuildStrategy &self) { return self.reduce_; },
793793
[](BuildStrategy &self, BuildStrategy::ReduceStrategy strategy) {
794+
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
794795
self.reduce_ = strategy;
795796
},
796797
R"DOC(The type is STR, there are two reduce strategies in ParallelExecutor,
@@ -804,6 +805,7 @@ All parameter, weight, gradient are variables in Paddle.
804805
[](const BuildStrategy &self) { return self.gradient_scale_; },
805806
[](BuildStrategy &self,
806807
BuildStrategy::GradientScaleStrategy strategy) {
808+
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
807809
self.gradient_scale_ = strategy;
808810
},
809811
R"DOC(The type is STR, there are three ways of defining :math:`loss@grad` in
@@ -815,6 +817,7 @@ All parameter, weight, gradient are variables in Paddle.
815817
"debug_graphviz_path",
816818
[](const BuildStrategy &self) { return self.debug_graphviz_path_; },
817819
[](BuildStrategy &self, const std::string &path) {
820+
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
818821
self.debug_graphviz_path_ = path;
819822
},
820823
R"DOC(The type is STR, debug_graphviz_path indicate the path that
@@ -824,6 +827,7 @@ All parameter, weight, gradient are variables in Paddle.
824827
"enable_data_balance",
825828
[](const BuildStrategy &self) { return self.enable_data_balance_; },
826829
[](BuildStrategy &self, bool b) {
830+
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
827831
self.enable_data_balance_ = b;
828832
}) // FIXME(chengudo): enable_data_balance seems not important
829833
.def_property(
@@ -832,6 +836,7 @@ All parameter, weight, gradient are variables in Paddle.
832836
return self.enable_sequential_execution_;
833837
},
834838
[](BuildStrategy &self, bool b) {
839+
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
835840
self.enable_sequential_execution_ = b;
836841
},
837842
R"DOC(The type is BOOL. If set True, the execution order of ops would be the same as what is in the program. Default False.)DOC")
@@ -841,6 +846,7 @@ All parameter, weight, gradient are variables in Paddle.
841846
return self.remove_unnecessary_lock_;
842847
},
843848
[](BuildStrategy &self, bool b) {
849+
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
844850
self.remove_unnecessary_lock_ = b;
845851
},
846852
R"DOC(The type is BOOL. If set True, some locks in GPU ops would be released and ParallelExecutor would run faster. Default False.)DOC")
@@ -850,15 +856,19 @@ All parameter, weight, gradient are variables in Paddle.
850856
return self.fuse_elewise_add_act_ops_;
851857
},
852858
[](BuildStrategy &self, bool b) {
859+
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
853860
self.fuse_elewise_add_act_ops_ = b;
854861
},
855862
R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
856863
to fuse elementwise_add_op and activation_op,
857864
it may make the execution faster. Default False)DOC")
858-
.def("_create_passes_from_strategy",
865+
.def("_finalize_strategy_and_create_passes",
859866
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
860-
return self.CreatePassesFromStrategy();
861-
});
867+
return self.CreatePassesFromStrategy(true);
868+
},
869+
R"DOC(Allow user to customized passes. Normally model-specific
870+
optimization passes should be defined in this way. BuildStrategy
871+
cannot be updated after being finalized.)DOC");
862872

863873
pe.def(py::init<const std::vector<platform::Place> &,
864874
const std::unordered_set<std::string> &,

python/paddle/fluid/tests/unittests/test_dist_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def run_trainer(self, args):
105105
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
106106

107107
if args.batch_merge_repeat > 1:
108-
pass_builder = build_stra._create_passes_from_strategy()
108+
pass_builder = build_stra._finalize_strategy_and_create_passes()
109109
mypass = pass_builder.insert_pass(
110110
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
111111
mypass.set_int("num_repeats", args.batch_merge_repeat)

python/paddle/fluid/tests/unittests/test_pass_builder.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,12 @@ def check_network_convergence(self, use_cuda, build_strategy=None):
9494

9595
def test_parallel_testing_with_new_strategy(self):
9696
build_strategy = fluid.BuildStrategy()
97-
pass_builder = build_strategy._create_passes_from_strategy()
97+
self.assertFalse(build_strategy.fuse_elewise_add_act_ops)
98+
build_strategy.fuse_elewise_add_act_ops = True
99+
pass_builder = build_strategy._finalize_strategy_and_create_passes()
100+
self.assertTrue("fuse_elewise_add_act_pass" in
101+
[p.type() for p in pass_builder.all_passes()])
102+
98103
origin_len = len(pass_builder.all_passes())
99104

100105
viz_pass = pass_builder.append_pass("graph_viz_pass")

0 commit comments

Comments
 (0)