Skip to content

Commit c3f6e0e

Browse files
committed
add namespace to Graph
1 parent 0b3465d commit c3f6e0e

16 files changed

+82
-71
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
6868
}
6969
}
7070

71-
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
71+
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
72+
ir::Node *node,
7273
size_t place_id) const {
7374
auto p = places_[place_id];
7475
auto *op_handle = result->Get<GraphOps>("ops").back().get();
@@ -192,8 +193,9 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
192193
// to parameter/gradients before optimizer ops, topo sort is insufficient. (
193194
// some optimizer ops might not depend on any nodes), we manually move all
194195
// optimizer nodes after last backward nodes.
195-
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const Graph &graph) {
196-
std::vector<ir::Node *> ret = ir::TopologySort(graph);
196+
// However, the assumption by SSAGraphBuilder should be relaxed in the future.
197+
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
198+
std::vector<ir::Node *> ret = ir::TopologySortOperations(graph);
197199
size_t last_backward = 0;
198200
std::vector<ir::Node *> optimize_ops;
199201
std::vector<ir::Node *> sorted_ret;
@@ -232,8 +234,8 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const Graph &graph) {
232234
return sorted_ret;
233235
}
234236

235-
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
236-
std::unique_ptr<Graph> graph) const {
237+
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
238+
std::unique_ptr<ir::Graph> graph) const {
237239
// Rebuild the graph structure.
238240
std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
239241
auto nodes = std::move(graph->nodes);
@@ -245,7 +247,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
245247
}
246248
}
247249

248-
Graph &result = *graph;
250+
ir::Graph &result = *graph;
249251
std::unordered_set<std::string> og_has_been_broadcast;
250252

251253
// We cannot invoke resize. It is a bug of GCC 4.8
@@ -397,7 +399,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
397399
#endif
398400
}
399401

400-
void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
402+
void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
401403
const std::string &p_name,
402404
size_t src_dev_id) const {
403405
#ifdef PADDLE_WITH_CUDA
@@ -427,7 +429,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
427429
}
428430
}
429431

430-
void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
432+
void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
431433
ir::Node *node,
432434
int dev_id) const {
433435
result->Get<GraphOps>("ops").emplace_back(
@@ -436,7 +438,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
436438
CreateOpHandleIOs(result, node, dev_id);
437439
}
438440

439-
void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
441+
void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
440442
const std::string &og) const {
441443
#ifdef PADDLE_WITH_CUDA
442444
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
@@ -466,7 +468,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
466468
}
467469

468470
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
469-
Graph *result, const std::vector<std::string> &datas) const {
471+
ir::Graph *result, const std::vector<std::string> &datas) const {
470472
#ifdef PADDLE_WITH_CUDA
471473
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
472474
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
@@ -529,7 +531,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
529531
return got == var_name_on_devices_.end() ? -1 : got->second;
530532
}
531533

532-
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
534+
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
533535
for (size_t i = 0; i < places_.size(); ++i) {
534536
// Insert ScaleCost OpHandle
535537
#ifdef PADDLE_WITH_CUDA
@@ -559,7 +561,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
559561
}
560562
}
561563

562-
void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
564+
void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
563565
ir::Node *node,
564566
size_t num_places) const {
565567
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
@@ -571,7 +573,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
571573
}
572574
}
573575

574-
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
576+
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
575577
const std::string &og,
576578
int dst_dev_id) const {
577579
#ifdef PADDLE_WITH_CUDA
@@ -604,7 +606,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
604606

605607
// Find the first occurence of `prev_op_name` and make current `op` depend
606608
// on it.
607-
void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
609+
void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
608610
const std::string &prev_op_name) const {
609611
for (auto &prev_op : result->Get<GraphOps>("ops")) {
610612
if (prev_op->Name() == prev_op_name) {
@@ -617,7 +619,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
617619
}
618620
}
619621

620-
void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
622+
void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
621623
ir::Node *node) const {
622624
int op_dev_id = -1;
623625
std::vector<std::string> input_var_names;
@@ -664,7 +666,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
664666
}
665667

666668
// Create RPC related op handles that connects its in ops and out ops.
667-
void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const {
669+
void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
670+
ir::Node *node) const {
668671
int op_dev_id = -1;
669672
if (node->Op()->Type() == "send") {
670673
op_dev_id = GetVarDeviceID(node->inputs[0]->Name());

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
4646
const std::vector<Scope *> &local_scopes,
4747
const BuildStrategy &strategy);
4848
#endif
49-
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override;
49+
std::unique_ptr<ir::Graph> Apply(
50+
std::unique_ptr<ir::Graph> graph) const override;
5051
int GetVarDeviceID(const std::string &varname) const override;
5152

5253
private:
53-
void CreateOpHandleIOs(Graph *result, ir::Node *node, size_t device_id) const;
54+
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
55+
size_t device_id) const;
5456

5557
private:
5658
std::string loss_var_name_;
@@ -64,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
6466

6567
bool IsScaleLossOp(ir::Node *node) const;
6668

67-
void CreateRPCOp(Graph *result, ir::Node *node) const;
68-
void CreateDistTrainOp(Graph *result, ir::Node *node) const;
69+
void CreateRPCOp(ir::Graph *result, ir::Node *node) const;
70+
void CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
6971

7072
/**
7173
* Is this operator as the end-point operator before/after send operator.
@@ -79,29 +81,30 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
7981
std::vector<std::string> FindDistTrainRecvVars(
8082
const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
8183

82-
void ConnectOp(Graph *result, OpHandleBase *op,
84+
void ConnectOp(ir::Graph *result, OpHandleBase *op,
8385
const std::string &prev_op_name) const;
8486

85-
void CreateComputationalOps(Graph *result, ir::Node *node,
87+
void CreateComputationalOps(ir::Graph *result, ir::Node *node,
8688
size_t num_places) const;
8789

88-
void CreateScaleLossGradOp(Graph *result) const;
89-
VarHandle *CreateReduceOp(Graph *result, const std::string &og,
90+
void CreateScaleLossGradOp(ir::Graph *result) const;
91+
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
9092
int dst_dev_id) const;
91-
void CreateComputationalOp(Graph *result, ir::Node *node, int dev_id) const;
93+
void CreateComputationalOp(ir::Graph *result, ir::Node *node,
94+
int dev_id) const;
9295

9396
bool IsParameterGradientOnce(
9497
const std::string &og,
9598
std::unordered_set<std::string> *og_has_been_broadcast) const;
9699

97100
int GetOpDeviceID(ir::Node *node) const;
98101

99-
void InsertAllReduceOp(Graph *result, const std::string &og) const;
102+
void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
100103

101-
void InsertDataBalanceOp(Graph *result,
104+
void InsertDataBalanceOp(ir::Graph *result,
102105
const std::vector<std::string> &datas) const;
103106

104-
void CreateBroadcastOp(Graph *result, const std::string &p_name,
107+
void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
105108
size_t src_dev_id) const;
106109

107110
bool IsSparseGradient(const std::string &og) const;

paddle/fluid/framework/details/ssa_graph_builder.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
namespace paddle {
1818
namespace framework {
1919
namespace details {
20-
void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
20+
void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
2121
for (auto &var_map : graph->Get<GraphVars>("vars")) {
2222
for (auto &name_pair : var_map) {
2323
if (name_pair.second.size() <= 1) {
@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
6060
}
6161

6262
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
63-
Graph *graph, ir::Node *node, const platform::Place &place,
63+
ir::Graph *graph, ir::Node *node, const platform::Place &place,
6464
size_t place_offset) {
6565
auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
6666
auto &var_holder = var_holders[node->Name()];
@@ -81,7 +81,7 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
8181
return var;
8282
}
8383

84-
void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
84+
void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
8585
ir::Node *new_node,
8686
const platform::Place &place,
8787
size_t place_offset) {
@@ -93,7 +93,7 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
9393
op_handle->AddOutput(var);
9494
}
9595

96-
void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
96+
void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
9797
for (auto &op : graph->Get<GraphOps>("ops")) {
9898
if (!op->Outputs().empty()) {
9999
continue;

paddle/fluid/framework/details/ssa_graph_builder.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,19 @@ class SSAGraphBuilder : public ir::Pass {
6464
*
6565
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
6666
*/
67-
static void PolishGraphToSupportDataHazards(Graph *graph);
67+
static void PolishGraphToSupportDataHazards(ir::Graph *graph);
6868

69-
static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, ir::Node *node,
69+
static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
7070
const platform::Place &place,
7171
size_t place_offset);
7272

7373
// Add an output variable (each_var_name, place, place_offset) to op_handle,
7474
// which belongs to graph
75-
static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
75+
static void CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
7676
ir::Node *new_node, const platform::Place &place,
7777
size_t place_offset);
7878

79-
static void AddOutputToLeafOps(Graph *graph);
79+
static void AddOutputToLeafOps(ir::Graph *graph);
8080
};
8181
} // namespace details
8282
} // namespace framework

paddle/fluid/framework/details/ssa_graph_checker.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace paddle {
2020
namespace framework {
2121
namespace details {
2222

23-
bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const {
23+
bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
2424
std::unordered_map<OpHandleBase *, size_t> pending_ops;
2525
std::unordered_set<VarHandleBase *> pending_vars;
2626
std::unordered_set<VarHandleBase *> ready_vars;

paddle/fluid/framework/details/ssa_graph_checker.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
2828
std::unique_ptr<SSAGraphBuilder>&& builder)
2929
: builder_(std::move(builder)) {}
3030

31-
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
31+
std::unique_ptr<ir::Graph> Apply(
32+
std::unique_ptr<ir::Graph> graph) const override {
3233
auto new_graph = builder_->Apply(std::move(graph));
3334
PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
3435
return new_graph;
@@ -38,7 +39,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
3839
return builder_->GetVarDeviceID(var_name);
3940
}
4041

41-
bool IsValidGraph(const Graph* graph) const;
42+
bool IsValidGraph(const ir::Graph* graph) const;
4243

4344
private:
4445
std::unique_ptr<SSAGraphBuilder> builder_;

paddle/fluid/framework/details/ssa_graph_printer.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace framework {
2121
namespace details {
2222

2323
template <typename Callback>
24-
static inline void IterAllVar(const Graph &graph, Callback callback) {
24+
static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
2525
for (auto &each : graph.Get<GraphVars>("vars")) {
2626
for (auto &pair1 : each) {
2727
for (auto &pair2 : pair1.second) {
@@ -35,7 +35,7 @@ static inline void IterAllVar(const Graph &graph, Callback callback) {
3535
}
3636
}
3737

38-
void GraphvizSSAGraphPrinter::Print(const Graph &graph,
38+
void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
3939
std::ostream &sout) const {
4040
size_t var_id = 0;
4141
std::unordered_map<const VarHandleBase *, size_t> vars;

paddle/fluid/framework/details/ssa_graph_printer.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ namespace details {
2525
class SSAGraphPrinter {
2626
public:
2727
virtual ~SSAGraphPrinter() {}
28-
virtual void Print(const Graph& graph, std::ostream& sout) const = 0;
28+
virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0;
2929
};
3030

3131
class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
3232
public:
33-
void Print(const Graph& graph, std::ostream& sout) const override;
33+
void Print(const ir::Graph& graph, std::ostream& sout) const override;
3434
};
3535

3636
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
@@ -50,7 +50,8 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
5050
stream_ptr_(std::move(sout)),
5151
stream_ref_(*stream_ptr_) {}
5252

53-
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
53+
std::unique_ptr<ir::Graph> Apply(
54+
std::unique_ptr<ir::Graph> graph) const override {
5455
auto new_graph = builder_->Apply(std::move(graph));
5556
printer_->Print(*new_graph, stream_ref_);
5657
return new_graph;

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ namespace framework {
2121
namespace details {
2222
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
2323
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
24-
const std::vector<platform::Place> &places, std::unique_ptr<Graph> &&graph)
24+
const std::vector<platform::Place> &places,
25+
std::unique_ptr<ir::Graph> &&graph)
2526
: graph_(std::move(graph)),
2627
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
2728
: nullptr),

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
4040
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
4141
const std::vector<Scope *> &local_scopes,
4242
const std::vector<platform::Place> &places,
43-
std::unique_ptr<Graph> &&graph);
43+
std::unique_ptr<ir::Graph> &&graph);
4444

4545
// Run a SSAGraph by a thread pool
4646
// Use topological sort algorithm
@@ -53,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
5353
details::OpHandleBase *op);
5454

5555
private:
56-
std::unique_ptr<Graph> graph_;
56+
std::unique_ptr<ir::Graph> graph_;
5757
std::unique_ptr<::ThreadPool> pool_;
5858
std::vector<Scope *> local_scopes_;
5959
std::vector<platform::Place> places_;

0 commit comments

Comments
 (0)