Skip to content

Commit 1366832

Browse files
committed
add dist pass barrier
1 parent 5988d0c commit 1366832

File tree

12 files changed

+128
-47
lines changed

12 files changed

+128
-47
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,20 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
4848
Executor::Executor(const platform::Place& place) : place_(place) {}
4949

5050
#ifdef PADDLE_WITH_DISTRIBUTE
51-
void Executor::Complete() {
52-
::paddle::operators::distributed::RPCClient::GetInstance<
53-
::paddle::operators::distributed::GRPCClient>()
54-
->SendComplete();
51+
void Executor::BeginPass() {
52+
auto client = ::paddle::operators::distributed::RPCClient::GetInstance<
53+
::paddle::operators::distributed::GRPCClient>();
54+
55+
client->SendBeginPass();
56+
client->Wait();
57+
}
58+
59+
void Executor::EndPass() {
60+
auto client = ::paddle::operators::distributed::RPCClient::GetInstance<
61+
::paddle::operators::distributed::GRPCClient>();
62+
63+
client->SendEndPass();
64+
client->Wait();
5565
}
5666
#endif
5767

paddle/fluid/framework/executor.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,14 @@ class Executor {
4646

4747
#ifdef PADDLE_WITH_DISTRIBUTE
4848
/*
49-
* Sending signal to pserver to mark current trainer stop.
49+
* Sending signal to pserver to mark current pass started.
5050
*/
51-
void Complete();
51+
void BeginPass();
52+
53+
/*
54+
* Sending signal to pserver to mark current pass finished.
55+
*/
56+
void EndPass();
5257
#endif
5358

5459
/* @Brief

paddle/fluid/operators/distributed/grpc_client.cc

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,17 @@ void GRPCClient::InitEventLoop() {
3535
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
3636
}
3737

38-
void GRPCClient::SendComplete() {
38+
void GRPCClient::SendBeginPass() {
3939
for (auto& it : channels_) {
40-
this->AsyncSendComplete(it.first);
40+
VLOG(3) << "send begin pass to: " << it.first;
41+
this->AsyncSendBeginPass(it.first);
42+
}
43+
}
44+
45+
void GRPCClient::SendEndPass() {
46+
for (auto& it : channels_) {
47+
VLOG(3) << "send end pass to " << it.first;
48+
this->AsyncSendEndPass(it.first);
4149
}
4250
}
4351

@@ -226,19 +234,32 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
226234
req_count_++;
227235
}
228236

229-
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
237+
void GRPCClient::AsyncSendBeginPass(const std::string& ep, int64_t time_out) {
230238
const auto ch = GetChannel(ep);
231239

232240
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
233241
s->Prepare(time_out);
234242

235243
sendrecv::VariableMessage req;
236-
req.set_varname(COMPLETE_MESSAGE);
244+
req.set_varname(BEGIN_PASS_MESSAGE);
237245
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
238246
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
239247
req_count_++;
240248
}
241249

250+
void GRPCClient::AsyncSendEndPass(const std::string& ep, int64_t time_out) {
251+
const auto ch = GetChannel(ep);
252+
253+
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
254+
s->Prepare(time_out);
255+
256+
sendrecv::VariableMessage req;
257+
req.set_varname(END_PASS_MESSAGE);
258+
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
259+
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
260+
req_count_++;
261+
}
262+
242263
void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
243264
const std::string& dir,
244265
int64_t time_out) {

paddle/fluid/operators/distributed/grpc_client.h

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,12 @@ class BaseProcessor {
7777
context_.reset(new grpc::ClientContext());
7878
var_h_ = var_info;
7979
context_->set_wait_for_ready(true);
80-
81-
std::chrono::system_clock::time_point deadline =
82-
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
83-
84-
context_->set_deadline(deadline);
80+
if (time_out) {
81+
std::chrono::system_clock::time_point deadline =
82+
std::chrono::system_clock::now() +
83+
std::chrono::milliseconds(time_out);
84+
context_->set_deadline(deadline);
85+
}
8586
}
8687

8788
virtual void Prepare(int64_t time_out) {
@@ -214,9 +215,17 @@ class GRPCClient : public RPCClient {
214215
void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
215216
int64_t time_out = FLAGS_rpc_deadline) override;
216217

218+
void AsyncSendBeginPass(const std::string& ep,
219+
int64_t time_out = FLAGS_rpc_deadline) override;
220+
221+
void AsyncSendEndPass(const std::string& ep,
222+
int64_t time_out = FLAGS_rpc_deadline) override;
223+
217224
void Wait() override;
218225

219-
void SendComplete() override;
226+
void SendBeginPass() override;
227+
228+
void SendEndPass() override;
220229

221230
protected:
222231
void InitImpl() override;
@@ -227,9 +236,6 @@ class GRPCClient : public RPCClient {
227236

228237
void Proceed();
229238

230-
void AsyncSendComplete(const std::string& ep,
231-
int64_t time_out = FLAGS_rpc_deadline);
232-
233239
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
234240

235241
private:

paddle/fluid/operators/distributed/request_handler.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,14 @@ constexpr char kRequestSend[] = "RequestSend";
3737
constexpr char kRequestGet[] = "RequestGet";
3838
constexpr char kRequestPrefetch[] = "RequestPrefetch";
3939
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
40+
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
4041

4142
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
4243
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
4344
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
4445
#define COMPLETE_MESSAGE "COMPLETE@RECV"
46+
#define BEGIN_PASS_MESSAGE "BEGIN_PASS@RECV"
47+
#define END_PASS_MESSAGE "END_PASS@RECV"
4548

4649
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
4750
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"

paddle/fluid/operators/distributed/request_handler_impl.cc

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,14 @@ bool RequestSendHandler::Handle(const std::string& varname,
5555
if (varname == BATCH_BARRIER_MESSAGE) {
5656
VLOG(3) << "sync: recv batch barrier message";
5757
rpc_server_->IncreaseBatchBarrier(kRequestSend);
58-
} else if (varname == COMPLETE_MESSAGE) {
59-
VLOG(3) << "sync: recv complete message";
60-
rpc_server_->DecreaseClientNum();
58+
} else if (varname == BEGIN_PASS_MESSAGE) {
59+
VLOG(3) << "sync: recv begin pass message";
60+
rpc_server_->WaitCond(kRequestSend);
61+
rpc_server_->BeginPass();
6162
} else {
6263
VLOG(3) << "sync: received var_name: " << varname;
63-
if (sync_mode_) {
64-
rpc_server_->WaitCond(kRequestSend);
65-
}
64+
rpc_server_->WaitCond(kRequestSend);
65+
VLOG(3) << "sync: processing received var: " << varname;
6666

6767
if (invar == nullptr) {
6868
LOG(ERROR) << "sync: Can not find server side var: " << varname;
@@ -91,21 +91,21 @@ bool RequestGetHandler::Handle(const std::string& varname,
9191
framework::Variable** outvar,
9292
const std::string& out_var_name) {
9393
VLOG(4) << "RequestGetHandler:" << varname;
94-
95-
if (varname != FETCH_BARRIER_MESSAGE) {
96-
if (sync_mode_) {
94+
if (sync_mode_) {
95+
if (varname == FETCH_BARRIER_MESSAGE) {
96+
VLOG(3) << "sync: recv fetch barrier message";
97+
rpc_server_->IncreaseBatchBarrier(kRequestGet);
98+
} else if (varname == END_PASS_MESSAGE) {
99+
rpc_server_->EndPass();
100+
} else {
97101
rpc_server_->WaitCond(kRequestGet);
102+
*outvar = scope_->FindVar(varname);
103+
}
104+
} else {
105+
if (varname != FETCH_BARRIER_MESSAGE && varname != END_PASS_MESSAGE) {
106+
*outvar = scope_->FindVar(varname);
98107
}
99-
*outvar = scope_->FindVar(varname);
100-
return true;
101-
}
102-
103-
// FETCH_BARRIER_MESSAGE
104-
if (sync_mode_) {
105-
VLOG(3) << "sync: recv fetch barrier message";
106-
rpc_server_->IncreaseBatchBarrier(kRequestGet);
107108
}
108-
109109
return true;
110110
}
111111

paddle/fluid/operators/distributed/rpc_client.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#include "gflags/gflags.h"
1717

1818
// default to 3min to avoid temporary network failures.
19-
DEFINE_int32(rpc_deadline, 180000, "deadline timeouts for rpc");
19+
DEFINE_int32(rpc_deadline, 30000, "deadline timeouts for rpc");
2020

2121
namespace paddle {
2222
namespace operators {

paddle/fluid/operators/distributed/rpc_client.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,17 @@ class RPCClient {
6060
const std::string& dir,
6161
int64_t time_out = FLAGS_rpc_deadline) = 0;
6262

63-
// SendComplete tells all the server that current trainer have no more data
64-
// to train, so that the pserver can reduce it's barrier count, and continue
65-
// to train with other trainers.
66-
virtual void SendComplete() = 0;
63+
virtual void AsyncSendBeginPass(const std::string& ep,
64+
int64_t time_out = FLAGS_rpc_deadline) = 0;
65+
66+
virtual void AsyncSendEndPass(const std::string& ep,
67+
int64_t time_out = FLAGS_rpc_deadline) = 0;
68+
69+
// BeginPass/EndPass tell all the pservers that a pass starts/ends, so that
70+
// the pserver can increase/reduce its barrier count, and continue to train
71+
// with other trainers.
72+
virtual void SendBeginPass() = 0;
73+
virtual void SendEndPass() = 0;
6774

6875
virtual void Wait() = 0;
6976

paddle/fluid/operators/distributed/rpc_server.cc

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ void RPCServer::SavePort() const {
4444
void RPCServer::WaitBarrier(const std::string& rpc_name) {
4545
std::unique_lock<std::mutex> lock(this->mutex_);
4646
barrier_cond_.wait(lock, [this, &rpc_name] {
47-
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
47+
return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) ||
48+
exit_flag_.load());
4849
});
4950

5051
VLOG(3) << "batch_barrier_: " << rpc_name << " "
@@ -63,10 +64,25 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
6364
}
6465
}
6566

66-
void RPCServer::DecreaseClientNum() {
67+
void RPCServer::BeginPass() {
68+
VLOG(4) << "RPCServer begin increase pass barrier";
6769
{
68-
std::unique_lock<std::mutex> lock(mutex_);
70+
std::unique_lock<std::mutex> locl(mutex_);
71+
client_num_++;
72+
VLOG(4) << "increase client_num to: " << client_num_;
73+
}
74+
barrier_cond_.notify_all();
75+
}
76+
77+
void RPCServer::EndPass() {
78+
VLOG(4) << "RPCServer begin increase pass barrier";
79+
{
80+
std::unique_lock<std::mutex> locl(mutex_);
6981
client_num_--;
82+
VLOG(4) << "decrease client_num to: " << client_num_;
83+
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
84+
barrier_counter_[kRequestGet]--;
85+
}
7086
}
7187
barrier_cond_.notify_all();
7288
}

paddle/fluid/operators/distributed/rpc_server.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ class RPCServer {
4343
bool IsExit() { return exit_flag_.load(); }
4444

4545
int GetSelectedPort() const { return selected_port_; }
46+
47+
int GetClientNum() const;
48+
4649
void SavePort() const;
4750

4851
// RegisterRPC, register the rpc method name to a handler
@@ -60,7 +63,10 @@ class RPCServer {
6063
void SetCond(const std::string& rpc_name);
6164
void WaitCond(const std::string& rpc_name);
6265
void IncreaseBatchBarrier(const std::string rpc_name);
63-
void DecreaseClientNum();
66+
67+
void BeginPass();
68+
void EndPass();
69+
6470
void ResetBarrierCounter();
6571

6672
protected:

0 commit comments

Comments
 (0)