
Commit 58be41f

Merge pull request #7608 from typhoonzero/distributed_split_selectedrows
Enhance distributed train performance
2 parents b7b5de7 + 0aff136
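
In brief: the server's one-shot Reset/Done/Wait synchronization becomes a reusable two-phase barrier (SetCond/WaitCond plus a blocking queue that counts completed client gets), the client's Get request no longer serializes the variable it is asking for, recv_op's Trainers attribute is renamed to Fanin, the OptimizeProgram is deserialized once instead of on every cluster batch, and the transpiler pre-creates the per-trainer gradient variables on the parameter server.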

7 files changed: +94 additions, -83 deletions


paddle/operators/detail/grpc_client.cc

Lines changed: 0 additions & 3 deletions
@@ -63,9 +63,6 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   sendrecv::VariableMessage req;
   req.set_varname(var_name);
 
-  auto* var = scope.FindVar(var_name);
-  SerializeToMessage(var_name, var, ctx, &req);
-
   // varhandle
   VarHandle var_h;
   var_h.ep = ep;
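
Note: AsyncGetVariable used to serialize the variable's current contents into the request even though a Get only needs the variable's name; dropping the FindVar/SerializeToMessage pair removes a pointless copy from every parameter fetch.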

paddle/operators/detail/grpc_server.cc

Lines changed: 30 additions & 20 deletions
@@ -36,7 +36,10 @@ class RequestBase {
 
   CallStatus Status() { return status_; }
   void SetStatus(CallStatus status) { status_ = status; }
-  virtual std::string GetReqName() { assert(false); }
+  virtual std::string GetReqName() {
+    assert(false);
+    return "";
+  }
 
  protected:
   grpc::ServerContext ctx_;

@@ -80,11 +83,13 @@ class RequestGet final : public RequestBase {
  public:
   explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
                       grpc::ServerCompletionQueue* cq, framework::Scope* scope,
-                      const platform::DeviceContext* dev_ctx)
+                      const platform::DeviceContext* dev_ctx,
+                      SimpleBlockQueue<char>* queue)
       : RequestBase(service, cq),
         responder_(&ctx_),
         scope_(scope),
-        dev_ctx_(dev_ctx) {
+        dev_ctx_(dev_ctx),
+        queue_(queue) {
     service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
   }

@@ -100,6 +105,7 @@ class RequestGet final : public RequestBase {
     // TODO(gongwb): check var's info.
     responder_.Finish(reply_, grpc::Status::OK, this);
     status_ = FINISH;
+    queue_->Push('c');
   }
 
  protected:

@@ -108,8 +114,15 @@ class RequestGet final : public RequestBase {
   ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
   framework::Scope* scope_;
   const platform::DeviceContext* dev_ctx_;
+  SimpleBlockQueue<char>* queue_;
 };
 
+void AsyncGRPCServer::WaitClientGet(int count) {
+  for (int i = 0; i < count; ++i) {
+    var_get_queue_.Pop();
+  }
+}
+
 void AsyncGRPCServer::RunSyncUpdate() {
   grpc::ServerBuilder builder;
   builder.AddListeningPort(address_, grpc::InsecureServerCredentials());

@@ -149,7 +162,6 @@ void AsyncGRPCServer::ShutdownQueue() {
 }
 
 // This URL explains why shutdown is complicate:
-// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
 void AsyncGRPCServer::ShutDown() {
   server_->Shutdown();
   ShutdownQueue();

@@ -170,10 +182,12 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
   if (is_shut_down_) {
     return;
   }
-  RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_);
+  RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
+                                   &var_get_queue_);
   VLOG(4) << "create Requestget status:" << get->Status();
 }
 
+// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
 void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
                                     std::string cq_name,
                                     std::function<void()> TryToRegisterNewOne) {

@@ -188,9 +202,9 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
     }
 
     PADDLE_ENFORCE(tag);
-    if (wait && !done_) {
-      Wait();
-    }
+    // FIXME(typhoonzero): de-couple the barriers with recv_op
+    if (cq_name == "cq_get") WaitCond(1);
+    if (cq_name == "cq_send") WaitCond(0);
 
     RequestBase* base = (RequestBase*)tag;
     // reference:

@@ -222,22 +236,18 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
   }
 }
 
-void AsyncGRPCServer::Wait() {
-  std::unique_lock<std::mutex> lock(this->mutex_);
-  condition_.wait(lock, [=] { return this->done_ == true; });
-}
-
-void AsyncGRPCServer::Reset() {
-  std::lock_guard<std::mutex> lock(this->mutex_);
-  done_ = false;
+void AsyncGRPCServer::WaitCond(int cond) {
+  std::unique_lock<std::mutex> lock(this->barrier_mutex_);
+  barrier_condition_.wait(lock,
+                          [=] { return this->barrier_cond_step_ == cond; });
 }
 
-void AsyncGRPCServer::Done() {
+void AsyncGRPCServer::SetCond(int cond) {
   {
-    std::lock_guard<std::mutex> lock(this->mutex_);
-    done_ = true;
+    std::lock_guard<std::mutex> lock(this->barrier_mutex_);
+    barrier_cond_step_ = cond;
   }
-  condition_.notify_all();
+  barrier_condition_.notify_all();
 }
 
 }  // namespace detail
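
The Reset/Done pair above could only flip a boolean once per round; the new integer step lets the completion-queue threads gate sends (cond 0) and gets (cond 1) separately. A minimal, self-contained sketch of this barrier pattern, using only the standard library (Barrier and the driver in main are illustrative names, not the Paddle classes):

// Sketch of the SetCond/WaitCond barrier, standard library only.
// cond 0 = "accepting gradient sends", cond 1 = "parameters ready, serve gets".
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

class Barrier {
 public:
  // Block until the server has switched to the given step.
  void WaitCond(int cond) {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [=] { return step_ == cond; });
  }
  // Switch the step and wake every thread blocked in WaitCond.
  void SetCond(int cond) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      step_ = cond;
    }
    cv_.notify_all();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int step_ = 0;
};

int main() {
  Barrier barrier;
  std::thread getter([&] {
    barrier.WaitCond(1);  // like HandleRequest blocking on "cq_get"
    std::cout << "get request served after optimization\n";
  });
  barrier.SetCond(0);  // receiving gradients ...
  barrier.SetCond(1);  // ... optimization done, release pending gets
  getter.join();
  return 0;
}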

paddle/operators/detail/grpc_server.h

Lines changed: 8 additions & 7 deletions
@@ -41,9 +41,10 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
 
   void RunSyncUpdate();
 
-  void Reset();
-
-  void Done();
+  // functions to sync server barrier status.
+  void WaitCond(int cond);
+  void SetCond(int cond);
+  void WaitClientGet(int count);
 
   void SetScope(framework::Scope *scope) { scope_ = scope; }
 

@@ -56,7 +57,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
   void ShutDown();
 
  protected:
-  void Wait();
   void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
                      std::string cq_name,
                      std::function<void()> TryToRegisterNewOne);

@@ -78,11 +78,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
   const platform::DeviceContext *dev_ctx_;
   // received variable from RPC, operators fetch variable from this queue.
   SimpleBlockQueue<MessageWithName> var_recv_queue_;
+  SimpleBlockQueue<char> var_get_queue_;
 
   // condition of the sub program
-  std::mutex mutex_;
-  volatile mutable bool done_;
-  std::condition_variable condition_;
+  std::mutex barrier_mutex_;
+  mutable int barrier_cond_step_;
+  std::condition_variable barrier_condition_;
 
   std::unique_ptr<std::thread> t_send_;
   std::unique_ptr<std::thread> t_get_;
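
WaitClientGet(count) pops count tokens from var_get_queue_, and RequestGet pushes one token per finished Get response, so the server cannot begin the next round until every trainer has fetched the new parameters. simple_block_queue.h itself is not part of this diff; a plausible minimal version of the interface the diff relies on (Push plus a blocking Pop), assuming the usual mutex-and-condition-variable design:

// Assumed sketch of SimpleBlockQueue -- the diff only needs Push and a
// blocking Pop. Not the actual simple_block_queue.h.
#include <condition_variable>
#include <deque>
#include <mutex>

template <typename T>
class SimpleBlockQueue {
 public:
  void Push(const T& item) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push_back(item);
    }
    cv_.notify_one();
  }
  // Blocks until an element is available.
  T Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !queue_.empty(); });
    T item = queue_.front();
    queue_.pop_front();
    return item;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<T> queue_;
};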

paddle/operators/recv_op.cc

Lines changed: 30 additions & 44 deletions
@@ -27,12 +27,17 @@ limitations under the License. */
 #include "paddle/operators/detail/grpc_server.h"
 #include "paddle/operators/detail/sendrecvop_utils.h"
 #include "paddle/operators/detail/simple_block_queue.h"
+#include "paddle/string/printf.h"
 
 #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
 
 namespace paddle {
 namespace operators {
 
+constexpr int kCondStart = 0;
+constexpr int kCondRunning = 1;
+constexpr int kCondDone = 2;
+
 void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
   service->RunSyncUpdate();
   VLOG(4) << "RunServer thread end";

@@ -77,42 +82,41 @@ class RecvOp : public framework::OperatorBase {
     if (grads_counter_.find(varname) == grads_counter_.end()) {
       grads_counter_[varname] = 0;
     }
-    char ret[256];
-    snprintf(ret, sizeof(ret), "%s.trainer_%d", varname.c_str(),
-             grads_counter_[varname]++);
-    return std::string(ret);
+    return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
   }
 
   void Run(const framework::Scope &scope,
            const platform::Place &dev_place) const override {
-    // FIXME(typhoonzero): no new scopes for every run.
-    framework::Scope &recv_scope = scope.NewScope();
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
+    framework::Scope &recv_scope = scope.NewScope();
 
     // FIXME(Yancey1989): initialize rpc server with laze mode.
     rpc_service_->SetScope(&recv_scope);
     rpc_service_->SetDevCtx(&dev_ctx);
     auto param_list = Attr<std::vector<std::string>>("ParamList");
     auto grad_list = Attr<std::vector<std::string>>("GradList");
-    auto trainer_count = Attr<int>("Trainers");
+    auto fan_in = Attr<int>("Fanin");
     size_t param_count = param_list.size();
 
-    rpc_service_->Reset();
+    std::string program_str = Attr<std::string>("OptimizeProgram");
+    framework::proto::ProgramDesc program_desc;
+    program_desc.ParseFromString(program_str);
+    framework::ProgramDesc program(program_desc);
+    framework::Executor executor(dev_place);
+
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     bool exit_flag = false;
-    VLOG(4) << "param_count:" << param_count
-            << " trainer_count:" << trainer_count;
+    int64_t barrier_size = param_count * fan_in;
     while (!exit_flag) {
-      // TODO(gognwb): simply this loop.
-      // Get from multiple trainers, we don't care about order in which
-      // the gradient arrives, just add suffix 0~n then average the gradient.
-      for (size_t i = 0; i < param_count * trainer_count; ++i) {
-        // blocking get one var from client.
+      // Get from multiple trainers, we don't care about the order in which
+      // the gradients arrives, just add suffix 0~n and merge the gradient.
+      rpc_service_->SetCond(0);
+      for (size_t i = 0; i < barrier_size; ++i) {
        const detail::MessageWithName &v = rpc_service_->Get();
        auto grad_var_name = v.first;
        if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
-          VLOG(4) << "received LISTEN_TERMINATE_MESSAGE and RunOp.Run() exit";
+          LOG(INFO) << "received terminate message and exit";
          exit_flag = true;
          break;
        }

@@ -121,49 +125,31 @@ class RecvOp : public framework::OperatorBase {
        if (it != grad_list.end()) {
          param_var_name = param_list[it - grad_list.begin()];
        } else {
-          LOG(ERROR) << "grad have no paired param found!\"" << grad_var_name
-                     << "\"";
+          LOG(ERROR) << "grad have no paired param:" << grad_var_name;
        }
        VLOG(3) << "recved grad: " << grad_var_name
                << " updating param: " << param_var_name;
-
-        auto *merged_grad = recv_scope.FindVar(grad_var_name);
-        if (merged_grad == nullptr) {
-          auto *ptr = recv_scope.Var(grad_var_name);
-          CreateTensorFromMessageType(ptr, v.second.type());
-          VLOG(3) << "Create Variable " << grad_var_name
-                  << " on recv scope, which pointer is " << ptr << " type is "
-                  << v.second.type();
-        }
-
-        if (trainer_count > 1) {
+        if (fan_in > 1) {
          grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
        }
-
-        auto *var = recv_scope.Var(grad_var_name);
+        auto *var = recv_scope.FindVar(grad_var_name);
+        if (var == nullptr) {
+          LOG(ERROR) << "can not find server side var: " << grad_var_name;
+          PADDLE_THROW("can not find server side var");
+        }
        detail::DeserializeFromMessage(v.second, dev_ctx, var);
      }
-
      if (exit_flag) {
        break;
      }
-
-      rpc_service_->Reset();
-
-      std::string program_str = Attr<std::string>("OptimizeProgram");
-      framework::proto::ProgramDesc program_desc;
-      program_desc.ParseFromString(program_str);
-      framework::ProgramDesc program(program_desc);
-      framework::Executor executor(dev_place);
-      // Run sub graph to get optimized tensor
      try {
        executor.Run(program, &recv_scope, 0, /*global_block*/
                     false /*create_local_scope*/, false /*create_vars*/);
      } catch (std::exception &e) {
        LOG(ERROR) << "run sub program error " << e.what();
      }
-
-      rpc_service_->Done();
+      rpc_service_->SetCond(1);
+      rpc_service_->WaitClientGet(barrier_size);
      grads_counter_.clear();
    }  // while(true)
  }

@@ -199,7 +185,7 @@ This operator will recv tensor from send_op
         "GradList", "type list of string",
         "grad->param name mapping to find which param to optimize.")
         .SetDefault({});
-    AddAttr<int>("Trainers", "type int",
+    AddAttr<int>("Fanin", "type int",
                  "Number of trainers in the current cluster job")
        .SetDefault(1);
  }
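
Taken together, one server round now runs: SetCond(0) to open the send phase; block until barrier_size = param_count * Fanin gradients arrive, renaming each to "<grad>.trainer_<n>" when Fanin > 1; run the pre-parsed OptimizeProgram; SetCond(1) to release the queued Get requests; and WaitClientGet(barrier_size) so every trainer has fetched the updated parameters before grads_counter_ is cleared. Hoisting the ProgramDesc parsing and Executor construction out of the while loop also avoids re-deserializing OptimizeProgram on every round. (The kCondStart/kCondRunning/kCondDone constants are introduced here, but the calls still pass the literals 0 and 1.)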

paddle/operators/send_op.cc

Lines changed: 3 additions & 0 deletions
@@ -41,10 +41,13 @@ class SendOp : public framework::OperatorBase {
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);
     for (size_t i = 0; i < ins.size(); i++) {
+      VLOG(3) << "sending " << ins[i];
       client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
     }
+    PADDLE_ENFORCE(client_.Wait());
 
     for (size_t i = 0; i < outs.size(); i++) {
+      VLOG(3) << "getting " << outs[i];
       client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
     }
 
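
The send side mirrors the server barrier: every AsyncSendVariable is issued first, PADDLE_ENFORCE(client_.Wait()) then guarantees all gradients are on the wire before any AsyncGetVariable is queued, which is what allows HandleRequest to treat "cq_send" and "cq_get" as two cleanly ordered phases.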

python/paddle/v2/fluid/distribute_transpiler.py

Lines changed: 14 additions & 1 deletion
@@ -420,6 +420,19 @@ def get_pserver_program(self, endpoint):
         pserver_program = Program()
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
             self._clone_var(pserver_program.global_block(), v)
+        for v in self.param_grad_ep_mapping[endpoint]["grads"]:
+            # create vars for each trainer in global scope, so
+            # we don't need to create them when grad arrives.
+            pserver_program.global_block().create_var(
+                name=v.name, persistable=True, dtype=v.dtype, shape=v.shape)
+            for trainer_id in xrange(self.trainers):
+                print("create variable for program: %s.trainer_%d" %
+                      (v.name, trainer_id))
+                pserver_program.global_block().create_var(
+                    name="%s.trainer_%d" % (v.name, trainer_id),
+                    persistable=True,
+                    dtype=v.dtype,
+                    shape=v.shape)
         # step6
         optimize_sub_program = Program()
         for idx, opt_op in enumerate(self.optimize_ops):

@@ -449,7 +462,7 @@ def get_pserver_program(self, endpoint):
                     p.name
                     for p in self.param_grad_ep_mapping[endpoint]["grads"]
                 ],
-                "Trainers": self.trainers
+                "Fanin": self.trainers
             })
         pserver_program.sync_with_cpp()
         return pserver_program
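
The "%s.trainer_%d" variables created here match the names RecvOp::GetGradVarNameForTrainer now produces via string::Sprintf, so a gradient from trainer k deserializes directly into the pre-created <grad>.trainer_k slot; that is why the recv loop switched from Scope::Var to FindVar and treats a missing variable as a hard error (PADDLE_THROW).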

python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py

Lines changed: 9 additions & 8 deletions
@@ -52,26 +52,27 @@
 place = fluid.CPUPlace()
 exe = fluid.Executor(place)
 
-t = fluid.DistributeTranspiler()
-# all parameter server endpoints list for spliting parameters
-pserver_endpoints = os.getenv("PSERVERS")
-# server endpoint for current node
-current_endpoint = os.getenv("SERVER_ENDPOINT")
-# run as trainer or parameter server
+pserver_endpoints = os.getenv("PSERVERS")  # all pserver endpoints
+trainers = int(os.getenv("TRAINERS"))  # total trainer count
+current_endpoint = os.getenv("SERVER_ENDPOINT")  # current pserver endpoint
 training_role = os.getenv("TRAINING_ROLE",
                           "TRAINER")  # get the training role: trainer/pserver
-t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
+t = fluid.DistributeTranspiler()
+t.transpile(
+    optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)
 
 if training_role == "PSERVER":
     if not current_endpoint:
         print("need env SERVER_ENDPOINT")
         exit(1)
     pserver_prog = t.get_pserver_program(current_endpoint)
-    exe.run(fluid.default_startup_program())
+    pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
+    exe.run(pserver_startup)
     exe.run(pserver_prog)
 elif training_role == "TRAINER":
     trainer_prog = t.get_trainer_program()
     feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
+    # TODO(typhoonzero): change trainer startup program to fetch parameters from pserver
     exe.run(fluid.default_startup_program())
 
 for pass_id in range(PASS_NUM):
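
The script now takes the trainer count from the TRAINERS environment variable instead of hard-coding trainers=2, and starts each pserver from the startup program returned by get_startup_program rather than the trainer's default startup program.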
