Skip to content

Commit e4d7d7a

Browse files
committed
pass refactoring
1 parent 142e832 commit e4d7d7a

17 files changed

+229
-81
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
244244
result.Set("vars", new GraphVars(places_.size()));
245245
result.Set("dep_vars", new GraphDepVars);
246246
result.Set("ops", new GraphOps);
247+
result.Set("sharded_var_device", new ShardedVarDevice);
247248

248249
// find send/recv vars so that we can place the distributed training
249250
// realted op in the place 0
@@ -276,11 +277,12 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
276277
// the block.
277278
is_forwarding = false;
278279
} else {
279-
int op_dev_id = GetOpDeviceID(node);
280+
int op_dev_id = GetOpDeviceID(result, node);
280281
if (op_dev_id != -1) { // This op only runs on one specific device.
281282
CreateComputationalOp(&result, node, op_dev_id);
282283
for (ir::Node *n : node->outputs) {
283-
var_name_on_devices_.emplace(n->Name(), op_dev_id);
284+
graph->Get<ShardedVarDevice>("sharded_var_device")
285+
.emplace(n->Name(), op_dev_id);
284286
}
285287
} else {
286288
// This op runs on all devices, and its output may have parameter's
@@ -317,7 +319,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
317319
case BuildStrategy::ReduceStrategy::kReduce:
318320
cur_device_id = GetAppropriateDeviceID({g_name});
319321
CreateReduceOp(&result, g_name, cur_device_id);
320-
var_name_on_devices_.emplace(g_name, cur_device_id);
322+
graph->Get<ShardedVarDevice>("sharded_var_device")
323+
.emplace(g_name, cur_device_id);
321324
bcast_var_name_set[cur_device_id].emplace(p_name);
322325
break;
323326
case BuildStrategy::ReduceStrategy::kAllReduce:
@@ -499,7 +502,8 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
499502
return is_pg_once;
500503
}
501504

502-
int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
505+
int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
506+
ir::Node *node) const {
503507
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
504508
return -1;
505509
}
@@ -512,15 +516,17 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
512516
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
513517

514518
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
515-
int dev_id = GetVarDeviceID(param_grad[1]);
519+
int dev_id = GetVarDeviceID(graph, param_grad[1]);
516520
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
517521
node->Op()->Type(), param_grad[0], param_grad[1]);
518522
return dev_id;
519523
}
520524

521-
int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
522-
auto got = var_name_on_devices_.find(varname);
523-
return got == var_name_on_devices_.end() ? -1 : got->second;
525+
int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
526+
const std::string &varname) const {
527+
auto &sharded_var_device = graph.Get<ShardedVarDevice>("sharded_var_device");
528+
auto got = sharded_var_device.find(varname);
529+
return got == sharded_var_device.end() ? -1 : got->second;
524530
}
525531

526532
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
@@ -625,20 +631,23 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
625631
if (node->Op()->Type() == "split_byref" ||
626632
node->Op()->Type() == "split_selected_rows") {
627633
// TODO(paddle-dev): getting the first var is not safe.
628-
op_dev_id = GetVarDeviceID(input_var_names[0]);
634+
op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
629635
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
630636
op_dev_id = GetAppropriateDeviceID(input_var_names);
631637
for (auto &varname : input_var_names) {
632-
var_name_on_devices_.emplace(varname, op_dev_id);
638+
result->Get<ShardedVarDevice>("sharded_var_device")
639+
.emplace(varname, op_dev_id);
633640
}
634641
}
635642
for (auto &varname : output_var_names) {
636-
var_name_on_devices_.emplace(varname, op_dev_id);
643+
result->Get<ShardedVarDevice>("sharded_var_device")
644+
.emplace(varname, op_dev_id);
637645
}
638646
} else if (node->Op()->Type() == "concat") {
639-
op_dev_id = GetVarDeviceID(input_var_names[0]);
647+
op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
640648
for (auto &varname : output_var_names) {
641-
var_name_on_devices_.emplace(varname, op_dev_id);
649+
result->Get<ShardedVarDevice>("sharded_var_device")
650+
.emplace(varname, op_dev_id);
642651
}
643652
} else {
644653
PADDLE_ENFORCE(
@@ -663,7 +672,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
663672
int op_dev_id = -1;
664673
if (node->Op()->Type() == "send") {
665674
// TODO(paddle-dev): getting the first var is not safe.
666-
op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
675+
op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name());
667676
PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
668677
"This hack no longer holds, please fix.");
669678
// the variable name which contains .block means it was splited by
@@ -678,7 +687,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
678687
}
679688
op_dev_id = GetAppropriateDeviceID(input_var_names);
680689
for (auto &varname : input_var_names) {
681-
var_name_on_devices_.emplace(varname, op_dev_id);
690+
result->Get<ShardedVarDevice>("sharded_var_device")
691+
.emplace(varname, op_dev_id);
682692
}
683693
}
684694
} else if (node->Op()->Type() == "recv") {
@@ -688,7 +698,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
688698
}
689699
op_dev_id = GetAppropriateDeviceID(output_var_names);
690700
for (auto &varname : output_var_names) {
691-
var_name_on_devices_.emplace(varname, op_dev_id);
701+
result->Get<ShardedVarDevice>("sharded_var_device")
702+
.emplace(varname, op_dev_id);
692703
}
693704
} else {
694705
// send_barrier and fetch_barrier op can be scheduled on device 0
@@ -730,3 +741,6 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
730741
} // namespace details
731742
} // namespace framework
732743
} // namespace paddle
744+
745+
REGISTER_PASS(multi_device_pass,
746+
paddle::framework::details::MultiDevSSAGraphBuilder);

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
3434
public:
3535
std::unique_ptr<ir::Graph> Apply(
3636
std::unique_ptr<ir::Graph> graph) const override;
37-
int GetVarDeviceID(const std::string &varname) const override;
3837

3938
private:
4039
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
@@ -51,6 +50,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
5150
mutable platform::NCCLContextMap *nccl_ctxs_;
5251
#endif
5352

53+
int GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const;
54+
5455
bool IsScaleLossOp(ir::Node *node) const;
5556

5657
void CreateRPCOp(ir::Graph *result, ir::Node *node) const;
@@ -84,7 +85,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
8485
const std::string &og,
8586
std::unordered_set<std::string> *og_has_been_broadcast) const;
8687

87-
int GetOpDeviceID(ir::Node *node) const;
88+
int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const;
8889

8990
void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
9091

@@ -102,7 +103,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
102103
private:
103104
mutable BuildStrategy strategy_;
104105
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
105-
mutable std::unordered_map<std::string, int> var_name_on_devices_;
106106
mutable std::vector<int64_t> balance_vars_;
107107

108108
void SetCommunicationContext(OpHandleBase *op_handle,

paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
4040
ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
4141
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
4242
std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
43+
44+
const ir::Graph& Graph() const { return underlying_executor_->Graph(); }
45+
4346
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
4447

4548
private:

paddle/fluid/framework/details/ssa_graph_builder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
4747
// unordered.
4848
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
4949

50+
typedef std::unordered_map<std::string, int> ShardedVarDevice;
51+
5052
class SSAGraphBuilder : public ir::Pass {
5153
public:
5254
SSAGraphBuilder() {}
5355
virtual ~SSAGraphBuilder() {}
5456

55-
virtual int GetVarDeviceID(const std::string &var_name) const = 0;
56-
5757
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
5858

5959
protected:

paddle/fluid/framework/details/ssa_graph_builder_factory.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
namespace paddle {
2222
namespace framework {
2323
namespace details {
24-
std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
25-
std::unique_ptr<SSAGraphBuilder> res(new MultiDevSSAGraphBuilder);
24+
std::unique_ptr<ir::Pass> ParallelExecutorPassManager::Create() {
25+
std::unique_ptr<ir::Pass> res(new MultiDevSSAGraphBuilder);
2626
res->SetNotOwned<std::vector<platform::Place>>("places", &places_);
2727
res->SetNotOwned<std::string>("loss_var_name", &loss_var_name_);
2828
res->SetNotOwned<std::unordered_set<std::string>>("params", &param_names_);
@@ -33,18 +33,18 @@ std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
3333
#endif
3434

3535
if (!strategy_.debug_graphviz_path_.empty()) {
36-
SSAGraphBuilder *previous_pass = res.release();
36+
ir::Pass *previous_pass = res.release();
3737
res.reset(new SSAGraghBuilderWithPrinter);
38-
res->Set<SSAGraphBuilder>("previous_pass", previous_pass);
38+
res->Set<ir::Pass>("previous_pass", previous_pass);
3939
res->SetNotOwned<std::string>("debug_graphviz_path",
4040
&strategy_.debug_graphviz_path_);
4141
res->Set<GraphvizSSAGraphPrinter>("graph_printer",
4242
new GraphvizSSAGraphPrinter);
4343
}
4444

45-
SSAGraphBuilder *previous_pass = res.release();
45+
ir::Pass *previous_pass = res.release();
4646
res.reset(new SSAGraghBuilderWithChecker);
47-
res->Set<SSAGraphBuilder>("previous_pass", previous_pass);
47+
res->Set<ir::Pass>("previous_pass", previous_pass);
4848

4949
return res;
5050
}

paddle/fluid/framework/details/ssa_graph_builder_factory.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ namespace framework {
2929
class Scope;
3030
namespace details {
3131

32-
class SSAGraphBuilderFactory {
32+
class ParallelExecutorPassManager {
3333
public:
34-
SSAGraphBuilderFactory(const std::vector<platform::Place>& places,
35-
const std::string& loss_var_name,
36-
const std::unordered_set<std::string>& param_names,
37-
const std::vector<Scope*>& local_scopes,
38-
const BuildStrategy& strategy)
34+
ParallelExecutorPassManager(
35+
const std::vector<platform::Place>& places,
36+
const std::string& loss_var_name,
37+
const std::unordered_set<std::string>& param_names,
38+
const std::vector<Scope*>& local_scopes, const BuildStrategy& strategy)
3939
: places_(places),
4040
loss_var_name_(loss_var_name),
4141
param_names_(param_names),
@@ -52,7 +52,7 @@ class SSAGraphBuilderFactory {
5252
}
5353
#endif
5454

55-
std::unique_ptr<SSAGraphBuilder> Create();
55+
std::unique_ptr<ir::Pass> Create();
5656

5757
private:
5858
std::vector<platform::Place> places_;

paddle/fluid/framework/details/ssa_graph_checker.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,6 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
8585
} // namespace details
8686
} // namespace framework
8787
} // namespace paddle
88+
89+
REGISTER_PASS(multi_device_check_pass,
90+
paddle::framework::details::SSAGraghBuilderWithChecker);

paddle/fluid/framework/details/ssa_graph_checker.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,11 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
2626
public:
2727
std::unique_ptr<ir::Graph> Apply(
2828
std::unique_ptr<ir::Graph> graph) const override {
29-
auto new_graph =
30-
Get<SSAGraphBuilder>("previous_pass").Apply(std::move(graph));
29+
auto new_graph = Get<ir::Pass>("previous_pass").Apply(std::move(graph));
3130
PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
3231
return new_graph;
3332
}
3433

35-
int GetVarDeviceID(const std::string& var_name) const override {
36-
return Get<SSAGraphBuilder>("previous_pass").GetVarDeviceID(var_name);
37-
}
38-
3934
bool IsValidGraph(const ir::Graph* graph) const;
4035
};
4136

paddle/fluid/framework/details/ssa_graph_executor.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ class SSAGraphExecutor {
3232

3333
virtual ~SSAGraphExecutor();
3434

35-
virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
35+
virtual const ir::Graph& Graph() const = 0;
36+
37+
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
3638
};
3739
} // namespace details
3840
} // namespace framework

paddle/fluid/framework/details/ssa_graph_printer.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,6 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
8181
} // namespace details
8282
} // namespace framework
8383
} // namespace paddle
84+
85+
REGISTER_PASS(multi_device_print_pass,
86+
paddle::framework::details::SSAGraghBuilderWithPrinter);

0 commit comments

Comments
 (0)