Skip to content

Commit ab72d28

Browse files
committed
clean up and correctness check
1 parent aa1085d commit ab72d28

16 files changed

+184
-92
lines changed

doc/fluid/design/ir/draft.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,12 @@ can also fuse some `Graph`'s `Node`s.
7575
class Pass {
7676
public:
7777

78-
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
78+
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const {
79+
// Some correctness check.
80+
auto new_graph = ApplyImpl(std::move(graph));
81+
// Some correctness check.
82+
return new_graph;
83+
}
7984

8085
// Get a reference to the attributed previously set.
8186
template <typename AttrType>
@@ -89,6 +94,9 @@ class Pass {
8994
// should delete the attribute.
9095
template <typename AttrType>
9196
void SetNotOwned(const std::string &attr_name, AttrType *attr);
97+
98+
protected:
99+
virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const = 0;
92100
};
93101

94102
// In my_pass.cc

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 57 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,22 @@ namespace paddle {
3434
namespace framework {
3535
namespace details {
3636

37+
static const char kLossVarName[] = "loss_var_name";
38+
static const char kPlaces[] = "places";
39+
static const char kParams[] = "params";
40+
static const char kLocalScopes[] = "local_scopes";
41+
static const char kStrategy[] = "strategy";
42+
3743
void MultiDevSSAGraphBuilder::Init() const {
38-
loss_var_name_ = Get<const std::string>("loss_var_name");
39-
places_ = Get<const std::vector<platform::Place>>("places");
40-
local_scopes_ = Get<const std::vector<Scope *>>("local_scopes");
41-
strategy_ = Get<const BuildStrategy>("strategy");
44+
loss_var_name_ = Get<const std::string>(kLossVarName);
45+
places_ = Get<const std::vector<platform::Place>>(kPlaces);
46+
local_scopes_ = Get<const std::vector<Scope *>>(kLocalScopes);
47+
strategy_ = Get<const BuildStrategy>(kStrategy);
4248
#ifdef PADDLE_WITH_CUDA
4349
nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
4450
#endif
4551

46-
for (auto &p : Get<const std::unordered_set<std::string>>("params")) {
52+
for (auto &p : Get<const std::unordered_set<std::string>>(kParams)) {
4753
grad_names_.insert(GradVarName(p));
4854
}
4955
balance_vars_.resize(places_.size(), 0);
@@ -58,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
5864
ir::Node *node,
5965
size_t place_id) const {
6066
auto p = places_[place_id];
61-
auto *op_handle = result->Get<GraphOps>("ops").back().get();
67+
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
6268
op_handle->SetDeviceContext(p,
6369
platform::DeviceContextPool::Instance().Get(p));
6470

@@ -225,7 +231,7 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
225231
return sorted_ret;
226232
}
227233

228-
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
234+
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
229235
std::unique_ptr<ir::Graph> graph) const {
230236
Init();
231237
// Give the topology sort order and rebuild the graph structure.
@@ -241,10 +247,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
241247
std::unordered_set<std::string> og_has_been_broadcast;
242248

243249
// We cannot invoke resize. It is a bug of GCC 4.8
244-
result.Set("vars", new GraphVars(places_.size()));
245-
result.Set("dep_vars", new GraphDepVars);
246-
result.Set("ops", new GraphOps);
247-
result.Set("sharded_var_device", new ShardedVarDevice);
250+
result.Set(kGraphVars, new GraphVars(places_.size()));
251+
result.Set(kGraphDepVars, new GraphDepVars);
252+
result.Set(kGraphOps, new GraphOps);
253+
result.Set(kShardedVarDevice, new ShardedVarDevice);
248254

249255
// find send/recv vars so that we can place the distributed training
250256
// realted op in the place 0
@@ -281,7 +287,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
281287
if (op_dev_id != -1) { // This op only runs on one specific device.
282288
CreateComputationalOp(&result, node, op_dev_id);
283289
for (ir::Node *n : node->outputs) {
284-
graph->Get<ShardedVarDevice>("sharded_var_device")
290+
graph->Get<ShardedVarDevice>(kShardedVarDevice)
285291
.emplace(n->Name(), op_dev_id);
286292
}
287293
} else {
@@ -319,7 +325,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
319325
case BuildStrategy::ReduceStrategy::kReduce:
320326
cur_device_id = GetAppropriateDeviceID({g_name});
321327
CreateReduceOp(&result, g_name, cur_device_id);
322-
graph->Get<ShardedVarDevice>("sharded_var_device")
328+
graph->Get<ShardedVarDevice>(kShardedVarDevice)
323329
.emplace(g_name, cur_device_id);
324330
bcast_var_name_set[cur_device_id].emplace(p_name);
325331
break;
@@ -406,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
406412
result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
407413
local_scopes_, places_);
408414
#endif
409-
result->Get<GraphOps>("ops").emplace_back(op_handle);
415+
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
410416

411417
auto *in =
412-
result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get();
418+
result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back().get();
413419
op_handle->AddInput(in);
414420

415421
for (size_t i = 0; i < places_.size(); ++i) {
416422
auto &p = places_[i];
417423
SetCommunicationContext(op_handle, p);
418-
auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
424+
auto &vars = result->Get<GraphVars>(kGraphVars).at(i).at(p_name);
419425
auto *out_var = new VarHandle(
420426
result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
421427
i, p_name, p);
@@ -427,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
427433
void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
428434
ir::Node *node,
429435
int dev_id) const {
430-
result->Get<GraphOps>("ops").emplace_back(
436+
result->Get<GraphOps>(kGraphOps).emplace_back(
431437
new ComputationOpHandle(result->CreateOpNode(node->Op()),
432438
local_scopes_[dev_id], places_[dev_id]));
433439
CreateOpHandleIOs(result, node, dev_id);
@@ -436,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
436442
void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
437443
const std::string &og) const {
438444
#ifdef PADDLE_WITH_CUDA
439-
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
445+
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
440446
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
441447
local_scopes_, places_, nccl_ctxs_));
442448
#else
443-
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
449+
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
444450
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
445451
local_scopes_, places_));
446452
#endif
447-
auto *op_handle = result->Get<GraphOps>("ops").back().get();
453+
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
448454

449455
for (size_t i = 0; i < places_.size(); ++i) {
450456
auto &p = places_[i];
451457
SetCommunicationContext(op_handle, p);
452-
auto &vars = result->Get<GraphVars>("vars")[i][og];
458+
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
453459
PADDLE_ENFORCE(!vars.empty());
454460
auto &prev_grad = vars.back();
455461
op_handle->AddInput(prev_grad.get());
@@ -465,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
465471
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
466472
ir::Graph *result, const std::vector<std::string> &datas) const {
467473
#ifdef PADDLE_WITH_CUDA
468-
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
474+
result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
469475
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
470476
local_scopes_, places_, nccl_ctxs_));
471477
#else
472-
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
478+
result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
473479
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
474480
local_scopes_, places_));
475481
#endif
476-
auto *op_handle = result->Get<GraphOps>("ops").back().get();
482+
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
477483
for (size_t i = 0; i < places_.size(); ++i) {
478484
auto &p = places_[i];
479485
SetCommunicationContext(op_handle, p);
480486
for (const std::string &d_name : datas) {
481-
auto &vars = result->Get<GraphVars>("vars")[i][d_name];
487+
auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
482488
PADDLE_ENFORCE(!vars.empty());
483489
op_handle->AddInput(vars.back().get());
484490
auto var = new VarHandle(
@@ -524,7 +530,7 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
524530

525531
int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
526532
const std::string &varname) const {
527-
auto &sharded_var_device = graph.Get<ShardedVarDevice>("sharded_var_device");
533+
auto &sharded_var_device = graph.Get<ShardedVarDevice>(kShardedVarDevice);
528534
auto got = sharded_var_device.find(varname);
529535
return got == sharded_var_device.end() ? -1 : got->second;
530536
}
@@ -544,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
544550
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
545551
local_scopes_.size(), local_scopes_[i], places_[i],
546552
communication_dev_ctx);
547-
result->Get<GraphOps>("ops").emplace_back(op_handle);
553+
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
548554

549555
// FIXME: Currently ScaleLossGradOp only use device_count as scale
550556
// factor. So it does not depend on any other operators.
@@ -565,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
565571
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
566572
auto p = places_[scope_idx];
567573
auto s = local_scopes_[scope_idx];
568-
result->Get<GraphOps>("ops").emplace_back(
574+
result->Get<GraphOps>(kGraphOps).emplace_back(
569575
new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
570576
CreateOpHandleIOs(result, node, scope_idx);
571577
}
@@ -575,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
575581
const std::string &og,
576582
int dst_dev_id) const {
577583
#ifdef PADDLE_WITH_CUDA
578-
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
584+
result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
579585
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
580586
local_scopes_, places_, nccl_ctxs_));
581587
#else
582-
result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
588+
result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
583589
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
584590
local_scopes_, places_));
585591
#endif
586-
auto *op_handle = result->Get<GraphOps>("ops").back().get();
592+
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
587593

588594
for (size_t i = 0; i < places_.size(); ++i) {
589595
auto &p = places_[i];
590596
SetCommunicationContext(op_handle, p);
591-
auto &vars = result->Get<GraphVars>("vars")[i][og];
597+
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
592598
PADDLE_ENFORCE(!vars.empty());
593599
auto &prev_grad = vars.back();
594600
op_handle->AddInput(prev_grad.get());
595601
}
596-
auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
602+
auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og];
597603
auto var =
598604
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
599605
vars.size(), dst_dev_id, og, places_[dst_dev_id]);
@@ -606,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
606612
// on it.
607613
void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
608614
const std::string &prev_op_name) const {
609-
for (auto &prev_op : result->Get<GraphOps>("ops")) {
615+
for (auto &prev_op : result->Get<GraphOps>(kGraphOps)) {
610616
if (prev_op->Name() == prev_op_name) {
611617
auto *dep_var = new DummyVarHandle(result->CreateControlDepVar());
612618
prev_op->AddOutput(dep_var);
613-
result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
619+
result->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
614620
op->AddInput(dep_var);
615621
}
616622
}
@@ -635,18 +641,18 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
635641
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
636642
op_dev_id = GetAppropriateDeviceID(input_var_names);
637643
for (auto &varname : input_var_names) {
638-
result->Get<ShardedVarDevice>("sharded_var_device")
644+
result->Get<ShardedVarDevice>(kShardedVarDevice)
639645
.emplace(varname, op_dev_id);
640646
}
641647
}
642648
for (auto &varname : output_var_names) {
643-
result->Get<ShardedVarDevice>("sharded_var_device")
649+
result->Get<ShardedVarDevice>(kShardedVarDevice)
644650
.emplace(varname, op_dev_id);
645651
}
646652
} else if (node->Op()->Type() == "concat") {
647653
op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
648654
for (auto &varname : output_var_names) {
649-
result->Get<ShardedVarDevice>("sharded_var_device")
655+
result->Get<ShardedVarDevice>(kShardedVarDevice)
650656
.emplace(varname, op_dev_id);
651657
}
652658
} else {
@@ -661,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
661667

662668
CreateComputationalOp(result, node, op_dev_id);
663669
if (node->Op()->Type() == "concat") {
664-
ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
670+
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
665671
"fetch_barrier");
666672
}
667673
}
@@ -687,7 +693,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
687693
}
688694
op_dev_id = GetAppropriateDeviceID(input_var_names);
689695
for (auto &varname : input_var_names) {
690-
result->Get<ShardedVarDevice>("sharded_var_device")
696+
result->Get<ShardedVarDevice>(kShardedVarDevice)
691697
.emplace(varname, op_dev_id);
692698
}
693699
}
@@ -698,7 +704,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
698704
}
699705
op_dev_id = GetAppropriateDeviceID(output_var_names);
700706
for (auto &varname : output_var_names) {
701-
result->Get<ShardedVarDevice>("sharded_var_device")
707+
result->Get<ShardedVarDevice>(kShardedVarDevice)
702708
.emplace(varname, op_dev_id);
703709
}
704710
} else {
@@ -709,17 +715,17 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
709715
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
710716
node->Op()->Type());
711717

712-
result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
718+
result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
713719
result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
714720
node->Op()->Type(), places_[op_dev_id]));
715721

716722
if (node->Op()->Type() == "send_barrier") {
717-
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
723+
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "send");
718724
} else if (node->Op()->Type() == "recv") {
719-
ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
725+
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
720726
"send_barrier");
721727
} else if (node->Op()->Type() == "fetch_barrier") {
722-
ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv");
728+
ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "recv");
723729
} else if (node->Op()->Type() == "send") {
724730
// do nothing
725731
} else {
@@ -743,4 +749,9 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
743749
} // namespace paddle
744750

745751
REGISTER_PASS(multi_device_pass,
746-
paddle::framework::details::MultiDevSSAGraphBuilder);
752+
paddle::framework::details::MultiDevSSAGraphBuilder)
753+
.RequirePassAttr(paddle::framework::details::kLossVarName)
754+
.RequirePassAttr(paddle::framework::details::kPlaces)
755+
.RequirePassAttr(paddle::framework::details::kParams)
756+
.RequirePassAttr(paddle::framework::details::kLocalScopes)
757+
.RequirePassAttr(paddle::framework::details::kStrategy);

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ class Scope;
3131
namespace details {
3232

3333
class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
34-
public:
35-
std::unique_ptr<ir::Graph> Apply(
34+
protected:
35+
std::unique_ptr<ir::Graph> ApplyImpl(
3636
std::unique_ptr<ir::Graph> graph) const override;
3737

3838
private:

paddle/fluid/framework/details/ssa_graph_builder.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace paddle {
1818
namespace framework {
1919
namespace details {
2020
void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
21-
for (auto &var_map : graph->Get<GraphVars>("vars")) {
21+
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
2222
for (auto &name_pair : var_map) {
2323
if (name_pair.second.size() <= 1) {
2424
continue;
@@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
5050
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
5151
read_op->AddOutput(dep_var);
5252
write_op->AddInput(dep_var);
53-
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
53+
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
5454
}
5555
}
5656
}
@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
6060
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
6161
ir::Graph *graph, ir::Node *node, const platform::Place &place,
6262
size_t place_offset) {
63-
auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
63+
auto &var_holders = graph->Get<GraphVars>(kGraphVars)[place_offset];
6464
auto &var_holder = var_holders[node->Name()];
6565
VarHandle *var = nullptr;
6666
if (var_holder.empty()) {
@@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
8383
ir::Node *new_node,
8484
const platform::Place &place,
8585
size_t place_offset) {
86-
auto &vars = graph->Get<GraphVars>("vars")[place_offset][new_node->Name()];
86+
auto &vars =
87+
graph->Get<GraphVars>(kGraphVars)[place_offset][new_node->Name()];
8788
size_t version = vars.size();
8889
auto var =
8990
new VarHandle(new_node, version, place_offset, new_node->Name(), place);
@@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
9293
}
9394

9495
void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
95-
for (auto &op : graph->Get<GraphOps>("ops")) {
96+
for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
9697
if (!op->Outputs().empty()) {
9798
continue;
9899
}
99100
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
100-
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
101+
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
101102
op->AddOutput(dummy_leaf);
102103
}
103104
}

0 commit comments

Comments
 (0)