Skip to content

Commit 34865f2

Browse files
authored
Trainer send term signal (#11220)
* wip * use executor.complete to end trainer * fix build * fix build with distribute off * fix typo * fix cmake typo * fix build
1 parent ca4d528 commit 34865f2

File tree

12 files changed

+80
-14
lines changed

12 files changed

+80
-14
lines changed

cmake/configure.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ endif()
118118
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}")
119119
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}")
120120

121+
if(WITH_DISTRIBUTE)
122+
add_definitions(-DPADDLE_WITH_DISTRIBUTE)
123+
endif()
124+
121125
if(WITH_GOLANG)
122126
# we need to symlink Paddle directory into GOPATH. If we
123127
# don't do it and we have code that depends on Paddle, go

paddle/fluid/framework/CMakeLists.txt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,13 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
8383

8484
cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
8585

86-
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
87-
framework_proto glog lod_rank_table feed_fetch_method)
86+
if(WITH_DISTRIBUTE)
87+
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr)
88+
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
89+
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
90+
else()
91+
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method)
92+
endif()
8893

8994

9095
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)

paddle/fluid/framework/executor.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ limitations under the License. */
2020
#include "paddle/fluid/framework/lod_tensor_array.h"
2121
#include "paddle/fluid/framework/op_registry.h"
2222
#include "paddle/fluid/framework/reader.h"
23+
#ifdef PADDLE_WITH_DISTRIBUTE
24+
#include "paddle/fluid/operators/detail/grpc_client.h"
25+
#endif
2326
#include "paddle/fluid/platform/place.h"
2427
#include "paddle/fluid/platform/profiler.h"
2528

@@ -44,6 +47,14 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
4447

4548
Executor::Executor(const platform::Place& place) : place_(place) {}
4649

50+
#ifdef PADDLE_WITH_DISTRIBUTE
51+
void Executor::Complete() {
52+
::paddle::operators::detail::RPCClient::GetInstance<
53+
::paddle::operators::detail::GRPCClient>()
54+
->SendComplete();
55+
}
56+
#endif
57+
4758
void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
4859
if (var_type == proto::VarType::LOD_TENSOR) {
4960
var->GetMutable<LoDTensor>();

paddle/fluid/framework/executor.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ class Executor {
4444

4545
explicit Executor(const platform::Place& place);
4646

47+
#ifdef PADDLE_WITH_DISTRIBUTE
48+
/*
49+
* Sends a signal to the pserver marking that the current trainer has stopped.
50+
*/
51+
void Complete();
52+
#endif
53+
4754
/* @Brief
4855
* Runtime evaluation of the given ProgramDesc under certain Scope
4956
*

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ void GRPCClient::InitEventLoop() {
3434
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
3535
}
3636

37+
void GRPCClient::SendComplete() {
38+
for (auto& it : channels_) {
39+
this->AsyncSendComplete(it.first);
40+
}
41+
}
42+
3743
GRPCClient::~GRPCClient() {
3844
Wait();
3945
cq_.Shutdown();
@@ -210,6 +216,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
210216
req_count_++;
211217
}
212218

219+
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
220+
const auto ch = GetChannel(ep);
221+
222+
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
223+
s->Prepare(time_out);
224+
225+
sendrecv::VariableMessage req;
226+
req.set_varname(COMPLETE_MESSAGE);
227+
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
228+
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
229+
req_count_++;
230+
}
231+
213232
void GRPCClient::Wait() {
214233
std::unique_lock<std::mutex> lk(sync_mutex_);
215234
sync_cond_.wait(lk, [this] { return req_count_ == 0; });

paddle/fluid/operators/detail/grpc_client.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ class GRPCClient : public RPCClient {
195195

196196
void Wait() override;
197197

198+
void SendComplete() override;
199+
198200
protected:
199201
void InitImpl() override;
200202

@@ -204,6 +206,9 @@ class GRPCClient : public RPCClient {
204206

205207
void Proceed();
206208

209+
void AsyncSendComplete(const std::string& ep,
210+
int64_t time_out = RPCClient::rpc_time_out);
211+
207212
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
208213

209214
private:

paddle/fluid/operators/detail/request_handler.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ constexpr char kRequestPrefetch[] = "RequestPrefetch";
4040
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
4141
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
4242
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
43+
#define COMPLETE_MESSAGE "COMPLETE@RECV"
4344

4445
class RPCServer;
4546

paddle/fluid/operators/detail/request_handler_impl.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ bool RequestSendHandler::Handle(const std::string& varname,
4949
if (varname == BATCH_BARRIER_MESSAGE) {
5050
VLOG(3) << "sync: recv batch barrier message";
5151
rpc_server_->IncreaseBatchBarrier(kRequestSend);
52+
} else if (varname == COMPLETE_MESSAGE) {
53+
VLOG(3) << "sync: recv complete message";
54+
rpc_server_->DecreaseClientNum();
5255
} else {
5356
VLOG(3) << "sync: received var_name: " << varname;
5457
if (sync_mode_) {

paddle/fluid/operators/detail/rpc_client.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ class RPCClient {
5353
virtual void AsyncSendFetchBarrier(const std::string& ep,
5454
int64_t time_out = rpc_time_out) = 0;
5555

56+
// SendComplete tells all the servers that the current trainer has no more data
57+
// to train, so that the pserver can reduce its barrier count and continue
58+
// to train with the other trainers.
59+
virtual void SendComplete() = 0;
60+
5661
virtual void Wait() = 0;
5762

5863
static constexpr int64_t rpc_time_out = 120 * 1000;

paddle/fluid/operators/detail/rpc_server.cc

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ void RPCServer::SavePort() const {
4343

4444
void RPCServer::WaitBarrier(const std::string& rpc_name) {
4545
std::unique_lock<std::mutex> lock(this->mutex_);
46-
barrier_cond_.wait(lock, [=] {
46+
barrier_cond_.wait(lock, [this, &rpc_name] {
4747
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
4848
});
4949

@@ -53,19 +53,23 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) {
5353
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
5454
VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
5555
int b = 0;
56-
{
57-
std::unique_lock<std::mutex> lock(mutex_);
58-
b = ++barrier_counter_[rpc_name];
59-
}
60-
61-
VLOG(3) << "RPCServer IncreaseBatchBarrier " << rpc_name
62-
<< ", barrier_count:" << b << ", fan_in" << client_num_;
63-
56+
std::unique_lock<std::mutex> lock(mutex_);
57+
b = ++barrier_counter_[rpc_name];
6458
if (b >= client_num_) {
59+
lock.unlock();
6560
barrier_cond_.notify_all();
61+
lock.lock();
6662
}
6763
}
6864

65+
void RPCServer::DecreaseClientNum() {
66+
{
67+
std::unique_lock<std::mutex> lock(mutex_);
68+
client_num_--;
69+
}
70+
barrier_cond_.notify_all();
71+
}
72+
6973
void RPCServer::ResetBarrierCounter() {
7074
VLOG(3) << "RPCServer ResetBarrierCounter ";
7175
std::unique_lock<std::mutex> lock(mutex_);

0 commit comments

Comments
 (0)