Skip to content

Commit 268e9dc

Browse files
committed
polish code
1 parent ceefbf3 commit 268e9dc

File tree

3 files changed

+22
-20
lines changed

3 files changed

+22
-20
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
8484
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
8585
const ProgramDesc &program) const {
8686
std::vector<std::string> send_vars;
87+
// since parameters are all in block 0,
88+
// it's enough to only scan send ops in block 0
8789
for (auto *op : program.Block(0).AllOps()) {
88-
if (op->Type() == "send_vars" || op->Type() == "send") {
90+
// TODO(Yancey1989): use a graceful method to find send op,
91+
// instead of the hard-coded string
92+
if (op->Type() == "send_vars") {
8993
auto op_vars = op->InputArgumentNames();
9094
send_vars.reserve(send_vars.size() +
9195
std::distance(op_vars.begin(), op_vars.end()));
@@ -99,7 +103,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
99103
const ProgramDesc &program) const {
100104
std::vector<std::string> recv_vars;
101105
for (auto *op : program.Block(0).AllOps()) {
102-
if (op->Type() == "recv" || op->Type() == "send") {
106+
// TODO(Yancey1989): use a graceful method to find recv op,
107+
// instead of the hard-coded string
108+
if (op->Type() == "recv") {
103109
auto op_vars = op->OutputArgumentNames();
104110
recv_vars.reserve(recv_vars.size() +
105111
std::distance(op_vars.begin(), op_vars.end()));
@@ -122,6 +128,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
122128
auto checker = [](const std::vector<std::string> &opvars,
123129
const std::vector<std::string> &rpc_vars) -> bool {
124130
for (auto &var : opvars) {
131+
// a variable name with the suffix `.block` means it's a split
132+
// variable by (DistributeTranspiler)
133+
// [python/paddle/fluid/transpiler/distribute_transpiler.py]
125134
if (var.find(".block") != std::string::npos &&
126135
std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
127136
return true;
@@ -130,13 +139,8 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
130139
return false;
131140
};
132141

133-
if (op.Type() == "split" || op.Type() == "split_byref" ||
134-
op.Type() == "split_selected_rows") {
135-
return checker(op.OutputArgumentNames(), send_vars);
136-
} else if (op.Type() == "concat") {
137-
return checker(op.InputArgumentNames(), recv_vars);
138-
}
139-
142+
return checker(op.OutputArgumentNames(), send_vars) ||
143+
checker(op.InputArgumentNames(), recv_vars);
140144
return false;
141145
}
142146

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
3434
const std::string ep_val = ep;
3535
const std::string var_name_val = var_name;
3636
const framework::Scope* p_scope = &scope;
37-
const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);
37+
const auto ch = GetChannel(ep_val);
3838

3939
framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
4040
this] {
@@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
8888
const std::string ep_val = ep;
8989
const std::string var_name_val = var_name;
9090
const framework::Scope* p_scope = &scope;
91-
const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);
91+
const auto ch = GetChannel(ep_val);
9292

9393
framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
9494
this] {
@@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
132132
const std::string in_var_name_val = in_var_name;
133133
const std::string out_var_name_val = out_var_name;
134134
const framework::Scope* p_scope = &scope;
135-
const auto ch = GetChannel(ep_val, ep_val + ":" + in_var_name_val);
135+
const auto ch = GetChannel(ep_val);
136136

137137
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
138138
time_out, ch, this] {
@@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
165165
}
166166

167167
void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
168-
const auto ch = GetChannel(ep, ep);
168+
const auto ch = GetChannel(ep);
169169

170170
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
171171
s->Prepare(time_out);
@@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
178178
}
179179

180180
void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
181-
const auto ch = GetChannel(ep, ep);
181+
const auto ch = GetChannel(ep);
182182
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
183183
s->Prepare(time_out);
184184

@@ -248,10 +248,9 @@ bool RPCClient::Proceed() {
248248
delete c;
249249
return true;
250250
}
251-
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
252-
const std::string& key) {
251+
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
253252
std::unique_lock<std::mutex> lock(mutex_);
254-
auto it = channels_.find(key);
253+
auto it = channels_.find(ep);
255254
if (it != channels_.end()) {
256255
return it->second;
257256
}
@@ -263,7 +262,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
263262

264263
auto ch =
265264
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
266-
channels_[key] = ch;
265+
channels_[ep] = ch;
267266
return ch;
268267
}
269268

paddle/fluid/operators/detail/grpc_client.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,7 @@ class RPCClient {
191191

192192
private:
193193
bool Proceed();
194-
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep,
195-
const std::string& key);
194+
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
196195

197196
private:
198197
grpc::CompletionQueue cq_;

0 commit comments

Comments
 (0)