Skip to content

Commit 142e832

Browse files
committed
pass registration
1 parent 5b18355 commit 142e832

10 files changed

+191
-104
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,30 +34,16 @@ namespace paddle {
3434
namespace framework {
3535
namespace details {
3636

37+
void MultiDevSSAGraphBuilder::Init() const {
38+
loss_var_name_ = Get<std::string>("loss_var_name");
39+
places_ = Get<std::vector<platform::Place>>("places");
40+
local_scopes_ = Get<std::vector<Scope *>>("local_scopes");
41+
strategy_ = Get<BuildStrategy>("strategy");
3742
#ifdef PADDLE_WITH_CUDA
38-
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
39-
const std::vector<platform::Place> &places,
40-
const std::string &loss_var_name,
41-
const std::unordered_set<std::string> &params,
42-
const std::vector<Scope *> &local_scopes,
43-
platform::NCCLContextMap *nccl_ctxs, const BuildStrategy &strategy)
44-
: loss_var_name_(loss_var_name),
45-
places_(places),
46-
local_scopes_(local_scopes),
47-
nccl_ctxs_(nccl_ctxs),
48-
strategy_(strategy) {
49-
#else
50-
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
51-
const std::vector<platform::Place> &places,
52-
const std::string &loss_var_name,
53-
const std::unordered_set<std::string> &params,
54-
const std::vector<Scope *> &local_scopes, const BuildStrategy &strategy)
55-
: loss_var_name_(loss_var_name),
56-
places_(places),
57-
local_scopes_(local_scopes),
58-
strategy_(strategy) {
43+
nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
5944
#endif
60-
for (auto &p : params) {
45+
46+
for (auto &p : Get<std::unordered_set<std::string>>("params")) {
6147
grad_names_.insert(GradVarName(p));
6248
}
6349
balance_vars_.resize(places_.size(), 0);
@@ -241,6 +227,7 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
241227

242228
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
243229
std::unique_ptr<ir::Graph> graph) const {
230+
Init();
244231
// Give the topology sort order and rebuild the graph structure.
245232
std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
246233
auto nodes = graph->ReleaseNodes();

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,36 +32,23 @@ namespace details {
3232

3333
class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
3434
public:
35-
#ifdef PADDLE_WITH_CUDA
36-
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
37-
const std::string &loss_var_name,
38-
const std::unordered_set<std::string> &params,
39-
const std::vector<Scope *> &local_scopes,
40-
platform::NCCLContextMap *nccl_ctxs,
41-
const BuildStrategy &strategy);
42-
#else
43-
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
44-
const std::string &loss_var_name,
45-
const std::unordered_set<std::string> &params,
46-
const std::vector<Scope *> &local_scopes,
47-
const BuildStrategy &strategy);
48-
#endif
4935
std::unique_ptr<ir::Graph> Apply(
5036
std::unique_ptr<ir::Graph> graph) const override;
5137
int GetVarDeviceID(const std::string &varname) const override;
5238

5339
private:
5440
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
5541
size_t device_id) const;
42+
void Init() const;
5643

5744
private:
58-
std::string loss_var_name_;
59-
const std::vector<platform::Place> &places_;
60-
const std::vector<Scope *> &local_scopes_;
61-
std::unordered_set<std::string> grad_names_;
45+
mutable std::string loss_var_name_;
46+
mutable std::vector<platform::Place> places_;
47+
mutable std::vector<Scope *> local_scopes_;
48+
mutable std::unordered_set<std::string> grad_names_;
6249

6350
#ifdef PADDLE_WITH_CUDA
64-
platform::NCCLContextMap *nccl_ctxs_;
51+
mutable platform::NCCLContextMap *nccl_ctxs_;
6552
#endif
6653

6754
bool IsScaleLossOp(ir::Node *node) const;
@@ -113,7 +100,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
113100
const std::vector<std::string> &var_names) const;
114101

115102
private:
116-
BuildStrategy strategy_;
103+
mutable BuildStrategy strategy_;
117104
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
118105
mutable std::unordered_map<std::string, int> var_name_on_devices_;
119106
mutable std::vector<int64_t> balance_vars_;

paddle/fluid/framework/details/ssa_graph_builder_factory.cc

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,29 @@ namespace paddle {
2222
namespace framework {
2323
namespace details {
2424
std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
25-
std::unique_ptr<SSAGraphBuilder> res(
25+
std::unique_ptr<SSAGraphBuilder> res(new MultiDevSSAGraphBuilder);
26+
res->SetNotOwned<std::vector<platform::Place>>("places", &places_);
27+
res->SetNotOwned<std::string>("loss_var_name", &loss_var_name_);
28+
res->SetNotOwned<std::unordered_set<std::string>>("params", &param_names_);
29+
res->SetNotOwned<std::vector<Scope *>>("local_scopes", &local_scopes_);
30+
res->SetNotOwned<BuildStrategy>("strategy", &strategy_);
2631
#ifdef PADDLE_WITH_CUDA
27-
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
28-
local_scopes_, nccl_ctxs_, strategy_)
29-
#else
30-
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
31-
local_scopes_, strategy_)
32+
res->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nccl_ctxs_);
3233
#endif
33-
); // NOLINT
3434

3535
if (!strategy_.debug_graphviz_path_.empty()) {
36-
std::unique_ptr<std::ostream> fout(
37-
new std::ofstream(strategy_.debug_graphviz_path_));
38-
PADDLE_ENFORCE(fout->good());
39-
std::unique_ptr<GraphvizSSAGraphPrinter> graphviz_printer(
40-
new GraphvizSSAGraphPrinter());
41-
res.reset(new SSAGraghBuilderWithPrinter(
42-
std::move(fout), std::move(graphviz_printer), std::move(res)));
36+
SSAGraphBuilder *previous_pass = res.release();
37+
res.reset(new SSAGraghBuilderWithPrinter);
38+
res->Set<SSAGraphBuilder>("previous_pass", previous_pass);
39+
res->SetNotOwned<std::string>("debug_graphviz_path",
40+
&strategy_.debug_graphviz_path_);
41+
res->Set<GraphvizSSAGraphPrinter>("graph_printer",
42+
new GraphvizSSAGraphPrinter);
4343
}
44-
res.reset(new SSAGraghBuilderWithChecker(std::move(res)));
44+
45+
SSAGraphBuilder *previous_pass = res.release();
46+
res.reset(new SSAGraghBuilderWithChecker);
47+
res->Set<SSAGraphBuilder>("previous_pass", previous_pass);
4548

4649
return res;
4750
}

paddle/fluid/framework/details/ssa_graph_checker.h

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,19 @@ namespace details {
2424

2525
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
2626
public:
27-
explicit SSAGraghBuilderWithChecker(
28-
std::unique_ptr<SSAGraphBuilder>&& builder)
29-
: builder_(std::move(builder)) {}
30-
3127
std::unique_ptr<ir::Graph> Apply(
3228
std::unique_ptr<ir::Graph> graph) const override {
33-
auto new_graph = builder_->Apply(std::move(graph));
29+
auto new_graph =
30+
Get<SSAGraphBuilder>("previous_pass").Apply(std::move(graph));
3431
PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
3532
return new_graph;
3633
}
3734

3835
int GetVarDeviceID(const std::string& var_name) const override {
39-
return builder_->GetVarDeviceID(var_name);
36+
return Get<SSAGraphBuilder>("previous_pass").GetVarDeviceID(var_name);
4037
}
4138

4239
bool IsValidGraph(const ir::Graph* graph) const;
43-
44-
private:
45-
std::unique_ptr<SSAGraphBuilder> builder_;
4640
};
4741

4842
} // namespace details

paddle/fluid/framework/details/ssa_graph_printer.h

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
#pragma once
1616

17+
#include <fstream>
1718
#include <iosfwd>
19+
#include <ostream>
1820
#include <string>
1921
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
2022

@@ -35,37 +37,21 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
3537

3638
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
3739
public:
38-
SSAGraghBuilderWithPrinter(std::ostream& sout,
39-
std::unique_ptr<SSAGraphPrinter>&& printer,
40-
std::unique_ptr<SSAGraphBuilder>&& builder)
41-
: printer_(std::move(printer)),
42-
builder_(std::move(builder)),
43-
stream_ref_(sout) {}
44-
45-
SSAGraghBuilderWithPrinter(std::unique_ptr<std::ostream>&& sout,
46-
std::unique_ptr<SSAGraphPrinter>&& printer,
47-
std::unique_ptr<SSAGraphBuilder>&& builder)
48-
: printer_(std::move(printer)),
49-
builder_(std::move(builder)),
50-
stream_ptr_(std::move(sout)),
51-
stream_ref_(*stream_ptr_) {}
52-
5340
std::unique_ptr<ir::Graph> Apply(
5441
std::unique_ptr<ir::Graph> graph) const override {
55-
auto new_graph = builder_->Apply(std::move(graph));
56-
printer_->Print(*new_graph, stream_ref_);
42+
auto new_graph =
43+
Get<SSAGraphBuilder>("previous_pass").Apply(std::move(graph));
44+
45+
std::unique_ptr<std::ostream> fout(
46+
new std::ofstream(Get<std::string>("debug_graphviz_path")));
47+
PADDLE_ENFORCE(fout->good());
48+
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*new_graph, *fout);
5749
return new_graph;
5850
}
5951

6052
int GetVarDeviceID(const std::string& var_name) const override {
61-
return builder_->GetVarDeviceID(var_name);
53+
return Get<SSAGraphBuilder>("previous_pass").GetVarDeviceID(var_name);
6254
}
63-
64-
private:
65-
std::unique_ptr<SSAGraphPrinter> printer_;
66-
std::unique_ptr<SSAGraphBuilder> builder_;
67-
std::unique_ptr<std::ostream> stream_ptr_;
68-
std::ostream& stream_ref_;
6955
};
7056

7157
} // namespace details

paddle/fluid/framework/ir/graph_viz_pass.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ namespace ir {
2323

2424
std::unique_ptr<ir::Graph> GraphVizPass::Apply(
2525
std::unique_ptr<ir::Graph> graph) const {
26-
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path_));
26+
const std::string graph_viz_path = Get<std::string>("graph_viz_path");
27+
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
2728
PADDLE_ENFORCE(fout->good());
2829
std::ostream& sout = *fout;
2930

@@ -61,6 +62,9 @@ std::unique_ptr<ir::Graph> GraphVizPass::Apply(
6162
sout << "}\n";
6263
return graph;
6364
}
65+
6466
} // namespace ir
6567
} // namespace framework
6668
} // namespace paddle
69+
70+
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass);

paddle/fluid/framework/ir/graph_viz_pass.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,8 @@ namespace ir {
2929

3030
class GraphVizPass : public Pass {
3131
public:
32-
explicit GraphVizPass(const std::string& graph_viz_path)
33-
: graph_viz_path_(graph_viz_path) {}
34-
3532
std::unique_ptr<ir::Graph> Apply(
3633
std::unique_ptr<ir::Graph> graph) const override;
37-
38-
private:
39-
const std::string graph_viz_path_;
4034
};
4135

4236
} // namespace ir

paddle/fluid/framework/ir/pass.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,12 @@ limitations under the License. */
1515
#include "paddle/fluid/framework/ir/pass.h"
1616

1717
namespace paddle {
18-
namespace framework {} // namespace framework
18+
namespace framework {
19+
namespace ir {
20+
PassRegistry& PassRegistry::Instance() {
21+
static PassRegistry g_pass_info_map;
22+
return g_pass_info_map;
23+
}
24+
} // namespace ir
25+
} // namespace framework
1926
} // namespace paddle

0 commit comments

Comments
 (0)