Skip to content

Commit 6133efd

Browse files
author
Yancey
authored
Merge pull request #12218 from Yancey1989/rpc_complete_interface
Add rpc complete interface
2 parents 24bea40 + fb06ed7 commit 6133efd

File tree

14 files changed

+91
-103
lines changed

14 files changed

+91
-103
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ paddle.fluid.program_guard ArgSpec(args=[], varargs='args', keywords='kwds', def
3535
paddle.fluid.get_var ArgSpec(args=['name', 'program'], varargs=None, keywords=None, defaults=(None,))
3636
paddle.fluid.Executor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None)
3737
paddle.fluid.Executor.as_lodtensor ArgSpec(args=['self', 'data'], varargs=None, keywords=None, defaults=None)
38-
paddle.fluid.Executor.begin_pass ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
39-
paddle.fluid.Executor.end_pass ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
38+
paddle.fluid.Executor.close ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
4039
paddle.fluid.Executor.run ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False))
4140
paddle.fluid.global_scope ArgSpec(args=[], varargs=None, keywords=None, defaults=None)
4241
paddle.fluid.scope_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)

paddle/fluid/framework/executor.cc

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,13 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
4545

4646
Executor::Executor(const platform::Place& place) : place_(place) {}
4747

48+
void Executor::Close() {
4849
#ifdef PADDLE_WITH_DISTRIBUTE
49-
void Executor::BeginPass() {
5050
::paddle::operators::distributed::RPCClient::GetInstance<
5151
::paddle::operators::distributed::GRPCClient>()
52-
->SendBeginPass();
53-
}
54-
55-
void Executor::EndPass() {
56-
::paddle::operators::distributed::RPCClient::GetInstance<
57-
::paddle::operators::distributed::GRPCClient>()
58-
->SendEndPass();
59-
}
52+
->SendComplete();
6053
#endif
54+
}
6155

6256
void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
6357
if (var_type == proto::VarType::LOD_TENSOR) {

paddle/fluid/framework/executor.h

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,11 @@ class Executor {
4444

4545
explicit Executor(const platform::Place& place);
4646

47-
#ifdef PADDLE_WITH_DISTRIBUTE
4847
/*
49-
* Sending signal to pserver to mark current pass started.
48+
* Close this Executor.
49+
* Calling this method will send complete messages to all pserver instances.
5050
*/
51-
void BeginPass();
52-
53-
/*
54-
* Sending signal to pserver to mark current pass finished.
55-
*/
56-
void EndPass();
57-
#endif
51+
void Close();
5852

5953
/* @Brief
6054
* Runtime evaluation of the given ProgramDesc under certain Scope

paddle/fluid/operators/distributed/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ if(WITH_GRPC)
1818
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
1919
cc_test(grpc_serde_test SRCS grpc_serde_test.cc
2020
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
21-
cc_test(grpc_server_test SRCS rpc_server_test.cc
21+
cc_test(rpc_server_test SRCS rpc_server_test.cc
2222
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL)
2323
return()
2424
endif()

paddle/fluid/operators/distributed/grpc_client.cc

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,16 @@ void GRPCClient::InitEventLoop() {
3636
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
3737
}
3838

39-
void GRPCClient::SendBeginPass() {
40-
for (auto& it : channels_) {
41-
VLOG(3) << "send begin pass to: " << it.first;
42-
this->AsyncSendBeginPass(it.first);
43-
}
44-
this->Wait();
45-
}
46-
47-
void GRPCClient::SendEndPass() {
48-
for (auto& it : channels_) {
49-
VLOG(3) << "send end pass to " << it.first;
50-
this->AsyncSendEndPass(it.first);
39+
void GRPCClient::SendComplete() {
40+
std::unique_lock<std::mutex> lk(completed_mutex_);
41+
if (!completed_) {
42+
for (auto& it : channels_) {
43+
VLOG(3) << "send complete message to " << it.first;
44+
this->AsyncSendComplete(it.first);
45+
}
46+
PADDLE_ENFORCE(this->Wait(), "internal grpc error");
47+
completed_ = true;
5148
}
52-
this->Wait();
5349
}
5450

5551
GRPCClient::~GRPCClient() {
@@ -239,32 +235,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
239235
req_count_++;
240236
}
241237

242-
void GRPCClient::AsyncSendBeginPass(const std::string& ep, int64_t time_out) {
238+
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
243239
const auto ch = GetChannel(ep);
244240

245241
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
246242
s->Prepare(time_out);
247243

248244
sendrecv::VariableMessage req;
249-
req.set_varname(BEGIN_PASS_MESSAGE);
245+
req.set_varname(COMPLETE_MESSAGE);
250246
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
251247
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
252248
req_count_++;
253249
}
254250

255-
void GRPCClient::AsyncSendEndPass(const std::string& ep, int64_t time_out) {
256-
const auto ch = GetChannel(ep);
257-
258-
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
259-
s->Prepare(time_out);
260-
261-
sendrecv::VariableMessage req;
262-
req.set_varname(END_PASS_MESSAGE);
263-
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
264-
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
265-
req_count_++;
266-
}
267-
268251
void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
269252
const std::string& dir,
270253
int64_t time_out) {

paddle/fluid/operators/distributed/grpc_client.h

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ class CheckpointNotifyProcessor : public BaseProcessor {
174174

175175
class GRPCClient : public RPCClient {
176176
public:
177-
GRPCClient() : ok_(true) {}
177+
GRPCClient() : ok_(true), completed_(false) {}
178178
virtual ~GRPCClient();
179179

180180
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
@@ -201,17 +201,12 @@ class GRPCClient : public RPCClient {
201201
void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
202202
int64_t time_out = FLAGS_rpc_deadline) override;
203203

204-
void AsyncSendBeginPass(const std::string& ep,
205-
int64_t time_out = FLAGS_rpc_deadline) override;
206-
207-
void AsyncSendEndPass(const std::string& ep,
208-
int64_t time_out = FLAGS_rpc_deadline) override;
204+
void AsyncSendComplete(const std::string& ep,
205+
int64_t time_out = FLAGS_rpc_deadline) override;
209206

210207
bool Wait() override;
211208

212-
void SendBeginPass() override;
213-
214-
void SendEndPass() override;
209+
void SendComplete() override;
215210

216211
protected:
217212
void InitImpl() override;
@@ -238,6 +233,10 @@ class GRPCClient : public RPCClient {
238233
// mutex for GetChannel thread safety
239234
std::mutex chan_mutex_;
240235
DISABLE_COPY_AND_ASSIGN(GRPCClient);
236+
237+
// mutex for sending complete message only once
238+
std::mutex completed_mutex_;
239+
bool completed_;
241240
};
242241

243242
} // namespace distributed

paddle/fluid/operators/distributed/request_handler.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
4343
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
4444
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
4545
#define COMPLETE_MESSAGE "COMPLETE@RECV"
46-
#define BEGIN_PASS_MESSAGE "BEGIN_PASS@RECV"
47-
#define END_PASS_MESSAGE "END_PASS@RECV"
4846

4947
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
5048
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"

paddle/fluid/operators/distributed/request_handler_impl.cc

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,9 @@ bool RequestSendHandler::Handle(const std::string& varname,
5555
if (varname == BATCH_BARRIER_MESSAGE) {
5656
VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
5757
rpc_server_->IncreaseBatchBarrier(kRequestSend);
58-
} else if (varname == BEGIN_PASS_MESSAGE) {
59-
VLOG(3) << "sync: recv begin pass message";
60-
rpc_server_->WaitCond(kRequestSend);
61-
rpc_server_->BeginPass();
58+
} else if (varname == COMPLETE_MESSAGE) {
59+
VLOG(3) << "sync: recv complete message";
60+
rpc_server_->Complete();
6261
} else {
6362
VLOG(3) << "sync: received var_name: " << varname;
6463
rpc_server_->WaitCond(kRequestSend);
@@ -94,14 +93,12 @@ bool RequestGetHandler::Handle(const std::string& varname,
9493
if (varname == FETCH_BARRIER_MESSAGE) {
9594
VLOG(3) << "sync: recv fetch barrier message";
9695
rpc_server_->IncreaseBatchBarrier(kRequestGet);
97-
} else if (varname == END_PASS_MESSAGE) {
98-
rpc_server_->EndPass();
9996
} else {
10097
rpc_server_->WaitCond(kRequestGet);
10198
*outvar = scope_->FindVar(varname);
10299
}
103100
} else {
104-
if (varname != FETCH_BARRIER_MESSAGE && varname != END_PASS_MESSAGE) {
101+
if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
105102
*outvar = scope_->FindVar(varname);
106103
}
107104
}

paddle/fluid/operators/distributed/rpc_client.h

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,13 @@ class RPCClient {
6060
const std::string& dir,
6161
int64_t time_out = FLAGS_rpc_deadline) = 0;
6262

63-
virtual void AsyncSendBeginPass(const std::string& ep,
64-
int64_t time_out = FLAGS_rpc_deadline) = 0;
63+
virtual void AsyncSendComplete(const std::string& ep,
64+
int64_t time_out = FLAGS_rpc_deadline) = 0;
6565

66-
virtual void AsyncSendEndPass(const std::string& ep,
67-
int64_t time_out = FLAGS_rpc_deadline) = 0;
68-
69-
// BeginePass/EndPass tells all the pserver that start/end a pass, so that
70-
// the pserver can increase/reduce it's barrier count, and continue to train
66+
// Complete tells all the pserver instances that finishe the training,
67+
// the pserver can reduce it's barrier count, and continue to train
7168
// with other trainers.
72-
virtual void SendBeginPass() = 0;
73-
virtual void SendEndPass() = 0;
69+
virtual void SendComplete() = 0;
7470

7571
virtual bool Wait() = 0;
7672

paddle/fluid/operators/distributed/rpc_server.cc

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,7 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
6464
}
6565
}
6666

67-
void RPCServer::BeginPass() {
68-
VLOG(4) << "RPCServer begin increase pass barrier";
69-
{
70-
std::unique_lock<std::mutex> lock(mutex_);
71-
client_num_++;
72-
VLOG(4) << "increase client_num to: " << client_num_;
73-
}
74-
barrier_cond_.notify_all();
75-
}
76-
77-
void RPCServer::EndPass() {
78-
VLOG(4) << "RPCServer begin increase pass barrier";
67+
void RPCServer::Complete() {
7968
{
8069
std::unique_lock<std::mutex> lock(mutex_);
8170
client_num_--;
@@ -87,6 +76,11 @@ void RPCServer::EndPass() {
8776
barrier_cond_.notify_all();
8877
}
8978

79+
int RPCServer::GetClientNum() {
80+
std::unique_lock<std::mutex> lock(mutex_);
81+
return client_num_;
82+
}
83+
9084
void RPCServer::ResetBarrierCounter() {
9185
VLOG(3) << "RPCServer ResetBarrierCounter ";
9286
std::unique_lock<std::mutex> lock(mutex_);

0 commit comments

Comments
 (0)