Skip to content

Commit fe8f28c

Browse files
authored
Add GetVariableNoBarrier on brpc. (#15488)
1 parent 981fc2b commit fe8f28c

File tree

6 files changed

+111
-24
lines changed

6 files changed

+111
-24
lines changed

paddle/fluid/operators/distributed/CMakeLists.txt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ if(WITH_GRPC)
2020
collective_client.cc collective_server.cc
2121
${GRPC_SRCS}
2222
PROTO send_recv.proto
23-
DEPS lod_tensor selected_rows_functor memory)
23+
DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS})
2424

2525
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
2626
set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS})
@@ -32,15 +32,17 @@ else()
3232
set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc)
3333
set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
3434

35+
set(BRPC_DEPS brpc ssl crypto protobuf leveldb snappystream snappy zlib)
36+
3537
brpc_library(sendrecvop_rpc SRCS sendrecvop_utils.cc
3638
request_handler_impl.cc rpc_client.cc rpc_server.cc
3739
variable_response.cc
3840
collective_client.cc collective_server.cc
3941
${BRPC_SRCS}
4042
PROTO send_recv.proto
41-
DEPS lod_tensor selected_rows memory)
43+
DEPS lod_tensor selected_rows memory scope ${BRPC_DEPS})
4244

43-
set(RPC_DEPS sendrecvop_rpc brpc ssl crypto protobuf leveldb snappystream snappy zlib)
45+
set(RPC_DEPS sendrecvop_rpc ${BRPC_DEPS})
4446
cc_test(brpc_serde_test SRCS brpc/brpc_serde_test.cc
4547
DEPS ${RPC_DEPS} gflags glog executor proto_desc lookup_sparse_table_op SERIAL)
4648
endif()

paddle/fluid/operators/distributed/brpc/brpc_client.cc

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ VarHandlePtr BRPCClient::AsyncSendVar(const std::string& ep,
6262
const std::string var_name_val = var_name;
6363
const framework::Scope* p_scope = &scope;
6464
const auto ch_ptr = GetChannel(ep_val);
65-
const std::string method = "SendRPC";
65+
const std::string method = kSendRPC;
6666
VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
6767

6868
framework::AsyncIO([=] {
@@ -156,15 +156,18 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
156156
const platform::DeviceContext& ctx,
157157
const framework::Scope& scope,
158158
const std::string& var_name,
159+
const std::string& out_var_name,
159160
const std::string& method_name,
160161
int64_t time_out) {
161162
const platform::DeviceContext* p_ctx = &ctx;
162163
const std::string ep_val = ep;
163164
const std::string var_name_val = var_name;
165+
const std::string out_varname_val = out_var_name;
164166
const framework::Scope* p_scope = &scope;
165167
const auto ch_ptr = GetChannel(ep_val);
166-
const std::string method = "GetRPC";
167-
VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
168+
const std::string method = kGetRPC;
169+
VarHandlePtr var_h(
170+
new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
168171

169172
framework::AsyncIO([=] {
170173
auto ch_ctx = ch_ptr->Pop();
@@ -175,15 +178,18 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
175178

176179
sendrecv::VariableMessage req;
177180
req.set_varname(var_name_val);
181+
req.set_out_varname(out_varname_val);
178182
req.set_trainer_id(trainer_id_);
179183

180184
google::protobuf::Closure* done = brpc::NewCallback(
181185
&HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
182186

183187
platform::RecordRPCEvent record_event(method, p_ctx);
184188

185-
if (method_name == "GetMonomerVariable") {
189+
if (method_name == kGetMonomerRPC) {
186190
ch_ctx->stub->GetMonomerVariable(cntl, &req, response, done);
191+
} else if (method_name == kGetNoBarrierRPC) {
192+
ch_ctx->stub->GetVariableNoBarrier(cntl, &req, response, done);
187193
} else {
188194
ch_ctx->stub->GetVariable(cntl, &req, response, done);
189195
}
@@ -198,25 +204,39 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
198204
return var_h;
199205
}
200206

207+
VarHandlePtr BRPCClient::AsyncGetVarNoBarrier(
208+
const std::string& ep, const platform::DeviceContext& ctx,
209+
const framework::Scope& scope, const std::string& var_name,
210+
const std::string& out_var_name, int64_t time_out) {
211+
std::string var_name_no_barrier =
212+
string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);
213+
214+
return _AsyncGetVar(ep, ctx, scope, var_name_no_barrier, out_var_name,
215+
kGetNoBarrierRPC, time_out);
216+
}
217+
201218
VarHandlePtr BRPCClient::AsyncGetMonomerVariable(
202219
const std::string& ep, const platform::DeviceContext& ctx,
203220
const framework::Scope& scope, const std::string& var_name,
204221
int64_t time_out) {
205-
return _AsyncGetVar(ep, ctx, scope, var_name, "GetMonomerVariable", time_out);
222+
return _AsyncGetVar(ep, ctx, scope, var_name, var_name, kGetMonomerRPC,
223+
time_out);
206224
}
207225

208226
VarHandlePtr BRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
209227
const std::string& var_name,
210228
int64_t time_out) {
211-
return AsyncSendMessage(ep, "GetMonomerBarrier", var_name, time_out);
229+
return AsyncSendMessage(ep, kSendMonomerFetchBarrierRPC, var_name, time_out);
212230
}
213231

214232
VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
215233
const platform::DeviceContext& ctx,
216234
const framework::Scope& scope,
217235
const std::string& var_name,
236+
const std::string& out_var_name,
218237
int64_t time_out) {
219-
return _AsyncGetVar(ep, ctx, scope, var_name, "GetVariable", time_out);
238+
return _AsyncGetVar(ep, ctx, scope, var_name, out_var_name, kGetRPC,
239+
time_out);
220240
}
221241

222242
VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
@@ -234,7 +254,7 @@ VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
234254
const framework::Scope* p_scope = &scope;
235255
const auto ch_ptr = GetChannel(ep_val);
236256

237-
const std::string method = "PrefetchRPC";
257+
const std::string method = kPrefetchRPC;
238258

239259
VarHandlePtr var_h(
240260
new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
@@ -270,7 +290,7 @@ VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
270290

271291
VarHandlePtr BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
272292
int64_t time_out) {
273-
return AsyncSendMessage(ep, "BatchBarrierRPC", BATCH_BARRIER_MESSAGE,
293+
return AsyncSendMessage(ep, kBatchBarrierRPC, BATCH_BARRIER_MESSAGE,
274294
time_out);
275295
}
276296

@@ -286,7 +306,7 @@ VarHandlePtr BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
286306
sendrecv::VariableMessage req;
287307
req.set_varname(FETCH_BARRIER_MESSAGE);
288308

289-
const std::string method = "FetchBarrierRPC";
309+
const std::string method = kFetchBarrierRPC;
290310
// var handle
291311
VarHandlePtr var_h(
292312
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
@@ -367,7 +387,7 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
367387

368388
VarHandlePtr BRPCClient::AsyncSendComplete(const std::string& ep,
369389
int64_t time_out) {
370-
return AsyncSendMessage(ep, "SendCompleteRPC", COMPLETE_MESSAGE, time_out);
390+
return AsyncSendMessage(ep, kSendCompleteRPC, COMPLETE_MESSAGE, time_out);
371391
}
372392

373393
void BRPCClient::SendComplete() {
@@ -394,9 +414,9 @@ VarHandlePtr BRPCClient::AsyncSendVarMessage(
394414
google::protobuf::Closure* done = brpc::NewCallback(
395415
&HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
396416

397-
if (method_name == "CheckPointNotifyRPC") {
417+
if (method_name == kCheckPointNotifyRPC) {
398418
ch_ctx->stub->CheckpointNotify(cntl, &req, response, done);
399-
} else if (method_name == "GetMonomerBarrier") {
419+
} else if (method_name == kSendMonomerFetchBarrierRPC) {
400420
ch_ctx->stub->GetMonomerBarrier(cntl, &req, response, done);
401421
} else {
402422
ch_ctx->stub->SendVariable(cntl, &req, response, done);

paddle/fluid/operators/distributed/brpc/brpc_client.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class BRPCClient : public RPCClient {
6565
const platform::DeviceContext& ctx,
6666
const framework::Scope& scope,
6767
const std::string& var_name,
68+
const std::string& out_var_name,
6869
int64_t time_out = FLAGS_rpc_deadline) override;
6970

7071
VarHandlePtr AsyncGetMonomerBarrier(
@@ -76,6 +77,13 @@ class BRPCClient : public RPCClient {
7677
const framework::Scope& scope, const std::string& var_name,
7778
int64_t time_out = FLAGS_rpc_deadline) override;
7879

80+
VarHandlePtr AsyncGetVarNoBarrier(const std::string& ep,
81+
const platform::DeviceContext& ctx,
82+
const framework::Scope& scope,
83+
const std::string& var_name,
84+
const std::string& out_varname,
85+
int64_t time_out = FLAGS_rpc_deadline);
86+
7987
VarHandlePtr AsyncPrefetchVar(const std::string& ep,
8088
const platform::DeviceContext& ctx,
8189
const framework::Scope& scope,
@@ -103,6 +111,7 @@ class BRPCClient : public RPCClient {
103111
const platform::DeviceContext& ctx,
104112
const framework::Scope& scope,
105113
const std::string& var_name,
114+
const std::string& out_var_name,
106115
const std::string& method_name,
107116
int64_t time_out = FLAGS_rpc_deadline);
108117

paddle/fluid/operators/distributed/brpc/brpc_server.cc

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ class BRPCServiceImpl : public SendRecvService {
4545
rpc_server_->GetThreadNum(distributed::kRequestGet)));
4646
}
4747

48+
it = rpc_call_map.find(distributed::kRequestGetNoBarrier);
49+
if (it != rpc_call_map.end()) {
50+
request_getnobarrier_h_ = it->second;
51+
getnobarrier_threads_.reset(new paddle::framework::ThreadPool(
52+
rpc_server_->GetThreadNum(distributed::kRequestGetNoBarrier)));
53+
}
54+
4855
it = rpc_call_map.find(distributed::kRequestPrefetch);
4956
if (it != rpc_call_map.end()) {
5057
request_prefetch_h_ = it->second;
@@ -112,6 +119,14 @@ class BRPCServiceImpl : public SendRecvService {
112119
[=] { _GetVariable(cntl_butil, request, response, done); });
113120
}
114121

122+
void GetVariableNoBarrier(google::protobuf::RpcController* cntl_butil,
123+
const VariableMessage* request,
124+
VariableMessage* response,
125+
google::protobuf::Closure* done) override {
126+
getnobarrier_threads_->Run(
127+
[=] { _GetVariableNoBarrier(cntl_butil, request, response, done); });
128+
}
129+
115130
void _GetVariable(google::protobuf::RpcController* cntl_butil,
116131
const VariableMessage* request, VariableMessage* response,
117132
google::protobuf::Closure* done) {
@@ -122,23 +137,59 @@ class BRPCServiceImpl : public SendRecvService {
122137
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
123138

124139
std::string varname = request->varname();
140+
std::string out_varname = request->out_varname();
125141
VLOG(3) << "RequestGet varname:" << varname
142+
<< ", out_varname:" << out_varname
126143
<< ", trainer_id:" << request->trainer_id()
127144
<< ", from:" << cntl->remote_side();
128145

129146
auto scope = request_get_h_->scope();
130-
auto invar = scope->FindVar(varname);
147+
paddle::framework::Variable* invar = nullptr;
148+
int trainer_id = request->trainer_id();
149+
paddle::framework::Variable* outvar = nullptr;
150+
151+
request_get_h_->Handle(varname, scope, invar, &outvar, trainer_id,
152+
out_varname);
153+
154+
if (outvar) {
155+
distributed::SerializeToIOBuf(out_varname, outvar,
156+
*request_get_h_->dev_ctx(), response,
157+
&cntl->response_attachment(), "", false);
158+
}
159+
}
160+
161+
void _GetVariableNoBarrier(google::protobuf::RpcController* cntl_butil,
162+
const VariableMessage* request,
163+
VariableMessage* response,
164+
google::protobuf::Closure* done) {
165+
PADDLE_ENFORCE(request_getnobarrier_h_ != nullptr,
166+
"RequestGetNoBarrier handler should be registed first!");
167+
168+
brpc::ClosureGuard done_guard(done);
169+
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
170+
171+
std::string varname = request->varname();
172+
std::string out_varname = request->out_varname();
131173
int trainer_id = request->trainer_id();
174+
175+
VLOG(3) << "RequestGetNoBarrier varname:" << varname
176+
<< ", out_varname:" << out_varname << ", trainer_id:" << trainer_id
177+
<< ", from:" << cntl->remote_side();
178+
179+
auto scope = request_getnobarrier_h_->scope();
180+
paddle::framework::Variable* invar = nullptr;
132181
paddle::framework::Variable* outvar = nullptr;
133182

134-
request_get_h_->Handle(varname, scope, invar, &outvar, trainer_id);
183+
request_getnobarrier_h_->Handle(varname, scope, invar, &outvar, trainer_id,
184+
out_varname);
135185

136186
if (outvar) {
137-
distributed::SerializeToIOBuf(varname, outvar, *request_get_h_->dev_ctx(),
138-
response, &cntl->response_attachment(), "",
139-
false);
187+
distributed::SerializeToIOBuf(
188+
out_varname, outvar, *request_getnobarrier_h_->dev_ctx(), response,
189+
&cntl->response_attachment(), "", false);
140190
}
141191
}
192+
142193
void PrefetchVariable(google::protobuf::RpcController* cntl_butil,
143194
const VariableMessage* request,
144195
VariableMessage* response,
@@ -282,16 +333,18 @@ class BRPCServiceImpl : public SendRecvService {
282333
private:
283334
distributed::RequestHandler* request_send_h_{nullptr};
284335
distributed::RequestHandler* request_get_h_{nullptr};
336+
distributed::RequestHandler* request_getnobarrier_h_{nullptr};
285337
distributed::RequestHandler* request_prefetch_h_{nullptr};
286338
distributed::RequestHandler* request_checkpoint_h_{nullptr};
287339
distributed::RequestHandler* request_get_monomer_handler_h_{nullptr};
288340
distributed::RequestHandler* request_get_monomer_barrier_handler_h_{nullptr};
289341

290342
distributed::RPCServer* rpc_server_{nullptr};
291343

292-
// FIXME(gongwb): brpc should support process one rpce use one threadpool.
344+
// FIXME(gongwb): brpc should support process one rpc use one threadpool.
293345
std::unique_ptr<paddle::framework::ThreadPool> send_threads_;
294346
std::unique_ptr<paddle::framework::ThreadPool> get_threads_;
347+
std::unique_ptr<paddle::framework::ThreadPool> getnobarrier_threads_;
295348
std::unique_ptr<paddle::framework::ThreadPool> prefetch_threads_;
296349
std::unique_ptr<paddle::framework::ThreadPool> checkpoint_notify_threads_;
297350
};

paddle/scripts/paddle_build.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,8 @@ function run_brpc_test() {
328328
========================================
329329
EOF
330330
set +x
331-
declare -a other_tests=("test_listen_and_serv_op" "system_allocator_test")
331+
declare -a other_tests=("test_listen_and_serv_op" "system_allocator_test" \
332+
"rpc_server_test" "varhandle_test" "collective_server_test" "brpc_serde_test")
332333
all_tests=`ctest -N`
333334

334335
for t in "${other_tests[@]}"

python/paddle/fluid/transpiler/details/checkport.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import time
1717
import socket
1818
from contextlib import closing
19+
from six import string_types
1920

2021

2122
def wait_server_ready(endpoints):
@@ -32,6 +33,7 @@ def wait_server_ready(endpoints):
3233
3334
wait_server_ready(["127.0.0.1:8080", "127.0.0.1:8081"])
3435
"""
36+
assert not isinstance(endpoints, string_types)
3537
while True:
3638
all_ok = True
3739
not_ready_endpoints = []
@@ -45,7 +47,7 @@ def wait_server_ready(endpoints):
4547
all_ok = False
4648
not_ready_endpoints.append(ep)
4749
if not all_ok:
48-
sys.stderr.write("pserver not ready, wait 3 sec to retry...\n")
50+
sys.stderr.write("server not ready, wait 3 sec to retry...\n")
4951
sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints) +
5052
"\n")
5153
sys.stderr.flush()

0 commit comments

Comments
 (0)