Skip to content

Commit f7fd711

Browse files
author
Yancey
authored
Merge pull request #11868 from Yancey1989/dist_pass_barrier
add dist pass barrier
2 parents 27d6962 + 37410a0 commit f7fd711

File tree

11 files changed

+124
-44
lines changed

11 files changed

+124
-44
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,16 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
4646
Executor::Executor(const platform::Place& place) : place_(place) {}
4747

4848
#ifdef PADDLE_WITH_DISTRIBUTE
49-
void Executor::Complete() {
50-
::paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>()
51-
->SendComplete();
49+
void Executor::BeginPass() {
50+
::paddle::operators::distributed::RPCClient::GetInstance<
51+
::paddle::operators::distributed::GRPCClient>()
52+
->SendBeginPass();
53+
}
54+
55+
void Executor::EndPass() {
56+
::paddle::operators::distributed::RPCClient::GetInstance<
57+
::paddle::operators::distributed::GRPCClient>()
58+
->SendEndPass();
5259
}
5360
#endif
5461

paddle/fluid/framework/executor.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,14 @@ class Executor {
4646

4747
#ifdef PADDLE_WITH_DISTRIBUTE
4848
/*
49-
* Sending signal to pserver to mark current trainer stop.
49+
* Sends a signal to the pserver to mark the current pass as started.
5050
*/
51-
void Complete();
51+
void BeginPass();
52+
53+
/*
54+
* Sends a signal to the pserver to mark the current pass as finished.
55+
*/
56+
void EndPass();
5257
#endif
5358

5459
/* @Brief

paddle/fluid/operators/distributed/grpc_client.cc

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,20 @@ void GRPCClient::InitEventLoop() {
3535
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
3636
}
3737

38-
void GRPCClient::SendComplete() {
38+
void GRPCClient::SendBeginPass() {
3939
for (auto& it : channels_) {
40-
this->AsyncSendComplete(it.first);
40+
VLOG(3) << "send begin pass to: " << it.first;
41+
this->AsyncSendBeginPass(it.first);
4142
}
43+
this->Wait();
44+
}
45+
46+
void GRPCClient::SendEndPass() {
47+
for (auto& it : channels_) {
48+
VLOG(3) << "send end pass to " << it.first;
49+
this->AsyncSendEndPass(it.first);
50+
}
51+
this->Wait();
4252
}
4353

4454
GRPCClient::~GRPCClient() {
@@ -226,19 +236,32 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
226236
req_count_++;
227237
}
228238

229-
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
239+
void GRPCClient::AsyncSendBeginPass(const std::string& ep, int64_t time_out) {
230240
const auto ch = GetChannel(ep);
231241

232242
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
233243
s->Prepare(time_out);
234244

235245
sendrecv::VariableMessage req;
236-
req.set_varname(COMPLETE_MESSAGE);
246+
req.set_varname(BEGIN_PASS_MESSAGE);
237247
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
238248
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
239249
req_count_++;
240250
}
241251

252+
void GRPCClient::AsyncSendEndPass(const std::string& ep, int64_t time_out) {
253+
const auto ch = GetChannel(ep);
254+
255+
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
256+
s->Prepare(time_out);
257+
258+
sendrecv::VariableMessage req;
259+
req.set_varname(END_PASS_MESSAGE);
260+
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
261+
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
262+
req_count_++;
263+
}
264+
242265
void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
243266
const std::string& dir,
244267
int64_t time_out) {

paddle/fluid/operators/distributed/grpc_client.h

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,12 @@ class BaseProcessor {
7777
context_.reset(new grpc::ClientContext());
7878
var_h_ = var_info;
7979
context_->set_wait_for_ready(true);
80-
81-
std::chrono::system_clock::time_point deadline =
82-
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
83-
84-
context_->set_deadline(deadline);
80+
if (time_out) {
81+
std::chrono::system_clock::time_point deadline =
82+
std::chrono::system_clock::now() +
83+
std::chrono::milliseconds(time_out);
84+
context_->set_deadline(deadline);
85+
}
8586
}
8687

8788
virtual void Prepare(int64_t time_out) {
@@ -214,9 +215,17 @@ class GRPCClient : public RPCClient {
214215
void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
215216
int64_t time_out = FLAGS_rpc_deadline) override;
216217

218+
void AsyncSendBeginPass(const std::string& ep,
219+
int64_t time_out = FLAGS_rpc_deadline) override;
220+
221+
void AsyncSendEndPass(const std::string& ep,
222+
int64_t time_out = FLAGS_rpc_deadline) override;
223+
217224
void Wait() override;
218225

219-
void SendComplete() override;
226+
void SendBeginPass() override;
227+
228+
void SendEndPass() override;
220229

221230
protected:
222231
void InitImpl() override;
@@ -227,9 +236,6 @@ class GRPCClient : public RPCClient {
227236

228237
void Proceed();
229238

230-
void AsyncSendComplete(const std::string& ep,
231-
int64_t time_out = FLAGS_rpc_deadline);
232-
233239
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
234240

235241
private:

paddle/fluid/operators/distributed/request_handler.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,14 @@ constexpr char kRequestSend[] = "RequestSend";
3737
constexpr char kRequestGet[] = "RequestGet";
3838
constexpr char kRequestPrefetch[] = "RequestPrefetch";
3939
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
40+
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
4041

4142
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
4243
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
4344
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
4445
#define COMPLETE_MESSAGE "COMPLETE@RECV"
46+
#define BEGIN_PASS_MESSAGE "BEGIN_PASS@RECV"
47+
#define END_PASS_MESSAGE "END_PASS@RECV"
4548

4649
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
4750
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"

paddle/fluid/operators/distributed/request_handler_impl.cc

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,14 @@ bool RequestSendHandler::Handle(const std::string& varname,
5555
if (varname == BATCH_BARRIER_MESSAGE) {
5656
VLOG(3) << "sync: recv batch barrier message";
5757
rpc_server_->IncreaseBatchBarrier(kRequestSend);
58-
} else if (varname == COMPLETE_MESSAGE) {
59-
VLOG(3) << "sync: recv complete message";
60-
rpc_server_->DecreaseClientNum();
58+
} else if (varname == BEGIN_PASS_MESSAGE) {
59+
VLOG(3) << "sync: recv begin pass message";
60+
rpc_server_->WaitCond(kRequestSend);
61+
rpc_server_->BeginPass();
6162
} else {
6263
VLOG(3) << "sync: received var_name: " << varname;
63-
if (sync_mode_) {
64-
rpc_server_->WaitCond(kRequestSend);
65-
}
64+
rpc_server_->WaitCond(kRequestSend);
65+
VLOG(3) << "sync: processing received var: " << varname;
6666

6767
if (invar == nullptr) {
6868
LOG(ERROR) << "sync: Can not find server side var: " << varname;
@@ -91,21 +91,21 @@ bool RequestGetHandler::Handle(const std::string& varname,
9191
framework::Variable** outvar,
9292
const std::string& out_var_name) {
9393
VLOG(4) << "RequestGetHandler:" << varname;
94-
95-
if (varname != FETCH_BARRIER_MESSAGE) {
96-
if (sync_mode_) {
94+
if (sync_mode_) {
95+
if (varname == FETCH_BARRIER_MESSAGE) {
96+
VLOG(3) << "sync: recv fetch barrier message";
97+
rpc_server_->IncreaseBatchBarrier(kRequestGet);
98+
} else if (varname == END_PASS_MESSAGE) {
99+
rpc_server_->EndPass();
100+
} else {
97101
rpc_server_->WaitCond(kRequestGet);
102+
*outvar = scope_->FindVar(varname);
103+
}
104+
} else {
105+
if (varname != FETCH_BARRIER_MESSAGE && varname != END_PASS_MESSAGE) {
106+
*outvar = scope_->FindVar(varname);
98107
}
99-
*outvar = scope_->FindVar(varname);
100-
return true;
101-
}
102-
103-
// FETCH_BARRIER_MESSAGE
104-
if (sync_mode_) {
105-
VLOG(3) << "sync: recv fetch barrier message";
106-
rpc_server_->IncreaseBatchBarrier(kRequestGet);
107108
}
108-
109109
return true;
110110
}
111111

paddle/fluid/operators/distributed/rpc_client.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,17 @@ class RPCClient {
6060
const std::string& dir,
6161
int64_t time_out = FLAGS_rpc_deadline) = 0;
6262

63-
// SendComplete tells all the server that current trainer have no more data
64-
// to train, so that the pserver can reduce it's barrier count, and continue
65-
// to train with other trainers.
66-
virtual void SendComplete() = 0;
63+
virtual void AsyncSendBeginPass(const std::string& ep,
64+
int64_t time_out = FLAGS_rpc_deadline) = 0;
65+
66+
virtual void AsyncSendEndPass(const std::string& ep,
67+
int64_t time_out = FLAGS_rpc_deadline) = 0;
68+
69+
// SendBeginPass/SendEndPass tell every pserver that a pass starts/ends, so that
70+
// the pserver can increase/reduce its barrier count, and continue to train
71+
// with other trainers.
72+
virtual void SendBeginPass() = 0;
73+
virtual void SendEndPass() = 0;
6774

6875
virtual void Wait() = 0;
6976

paddle/fluid/operators/distributed/rpc_server.cc

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ void RPCServer::SavePort() const {
4444
void RPCServer::WaitBarrier(const std::string& rpc_name) {
4545
std::unique_lock<std::mutex> lock(this->mutex_);
4646
barrier_cond_.wait(lock, [this, &rpc_name] {
47-
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
47+
return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) ||
48+
exit_flag_.load());
4849
});
4950

5051
VLOG(3) << "batch_barrier_: " << rpc_name << " "
@@ -63,10 +64,25 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
6364
}
6465
}
6566

66-
void RPCServer::DecreaseClientNum() {
67+
void RPCServer::BeginPass() {
68+
VLOG(4) << "RPCServer begin increase pass barrier";
69+
{
70+
std::unique_lock<std::mutex> lock(mutex_);
71+
client_num_++;
72+
VLOG(4) << "increase client_num to: " << client_num_;
73+
}
74+
barrier_cond_.notify_all();
75+
}
76+
77+
void RPCServer::EndPass() {
78+
VLOG(4) << "RPCServer begin increase pass barrier";
6779
{
6880
std::unique_lock<std::mutex> lock(mutex_);
6981
client_num_--;
82+
VLOG(4) << "decrease client_num to: " << client_num_;
83+
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
84+
barrier_counter_[kRequestGet]--;
85+
}
7086
}
7187
barrier_cond_.notify_all();
7288
}

paddle/fluid/operators/distributed/rpc_server.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ class RPCServer {
4343
bool IsExit() { return exit_flag_.load(); }
4444

4545
int GetSelectedPort() const { return selected_port_; }
46+
47+
int GetClientNum() const;
48+
4649
void SavePort() const;
4750

4851
// RegisterRPC, register the rpc method name to a handler
@@ -60,7 +63,10 @@ class RPCServer {
6063
void SetCond(const std::string& rpc_name);
6164
void WaitCond(const std::string& rpc_name);
6265
void IncreaseBatchBarrier(const std::string rpc_name);
63-
void DecreaseClientNum();
66+
67+
void BeginPass();
68+
void EndPass();
69+
6470
void ResetBarrierCounter();
6571

6672
protected:

paddle/fluid/pybind/pybind.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,8 @@ All parameter, weight, gradient are variables in Paddle.
493493
py::class_<framework::Executor>(m, "Executor")
494494
.def(py::init<const platform::Place &>())
495495
#ifdef PADDLE_WITH_DISTRIBUTE
496-
.def("complete", &Executor::Complete)
496+
.def("begin_pass", &Executor::BeginPass)
497+
.def("end_pass", &Executor::EndPass)
497498
#endif
498499
.def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
499500
int block_id, bool create_local_scope, bool create_vars) {

0 commit comments

Comments
 (0)