Skip to content

Commit 036a90f

Browse files
typhoonzerogongweibao
authored andcommitted
Refine rpc client wait sync (#11132)
1 parent a385803 commit 036a90f

File tree

10 files changed

+80
-90
lines changed

10 files changed

+80
-90
lines changed

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 41 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,25 @@ void RPCClient::Init() {
3838
if (rpc_client_.get() == nullptr) {
3939
rpc_client_.reset(new RPCClient());
4040
}
41+
rpc_client_->InitEventLoop();
42+
}
43+
44+
void RPCClient::InitEventLoop() {
45+
// start the client process thread
46+
// TODO(wuyi): can make this in a threadpool
47+
client_thread_.reset(new std::thread(std::bind(&RPCClient::Proceed, this)));
48+
}
49+
50+
RPCClient::~RPCClient() {
51+
Wait();
52+
cq_.Shutdown();
53+
{
54+
std::lock_guard<std::mutex> guard(chan_mutex_);
55+
for (auto& it : channels_) {
56+
it.second.reset();
57+
}
58+
}
59+
client_thread_->join();
4160
}
4261

4362
bool RPCClient::AsyncSendVariable(const std::string& ep,
@@ -204,70 +223,37 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
204223
req_count_++;
205224
}
206225

207-
bool RPCClient::Wait() {
208-
VLOG(3) << "RPCClient begin Wait()"
209-
<< " req_count_:" << req_count_;
210-
if (req_count_ <= 0) {
211-
return true;
212-
}
213-
const size_t kReqCnt = req_count_;
214-
bool a[kReqCnt];
215-
std::vector<std::future<void>> waits(req_count_);
216-
std::mutex mu;
217-
218-
for (int i = 0; i < req_count_; i++) {
219-
waits[i] = framework::AsyncIO([i, &a, &mu, this] {
220-
bool ret = Proceed();
221-
std::lock_guard<std::mutex> l(mu);
222-
a[i] = ret;
223-
});
224-
}
225-
226-
for (int i = 0; i < req_count_; i++) {
227-
waits[i].wait();
228-
}
229-
230-
int last_req_count = req_count_;
231-
req_count_ = 0;
232-
233-
for (int i = 0; i < last_req_count; i++) {
234-
if (!a[i]) {
235-
return false;
236-
}
237-
}
238-
239-
return true;
226+
void RPCClient::Wait() {
227+
std::unique_lock<std::mutex> lk(sync_mutex_);
228+
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
240229
}
241230

242-
bool RPCClient::Proceed() {
243-
void* tag = NULL;
231+
void RPCClient::Proceed() {
232+
void* tag = nullptr;
244233
bool ok = false;
245234

246-
// request counts.
247-
if (!cq_.Next(&tag, &ok)) {
248-
LOG(ERROR) << "Get meets CompletionQueue error";
249-
return false;
250-
}
251-
252-
GPR_ASSERT(ok);
253-
PADDLE_ENFORCE(tag);
254-
255-
// TODO(gongwb): add more retries.
256-
BaseProcessor* c = static_cast<BaseProcessor*>(tag);
257-
if (!c->status_.ok()) {
258-
LOG(ERROR) << "proc param error:" << c->var_h_.String()
259-
<< " grpc error:" << c->status_.error_message();
235+
while (cq_.Next(&tag, &ok)) {
236+
BaseProcessor* c = static_cast<BaseProcessor*>(tag);
237+
GPR_ASSERT(ok);
238+
PADDLE_ENFORCE(c);
239+
if (c->status_.ok()) {
240+
c->Process();
241+
} else {
242+
LOG(ERROR) << "var: " << c->var_h_.String()
243+
<< " grpc error:" << c->status_.error_message();
244+
}
260245
delete c;
261-
return false;
246+
{
247+
std::lock_guard<std::mutex> lk(sync_mutex_);
248+
req_count_--;
249+
}
250+
sync_cond_.notify_all();
262251
}
263-
264-
c->Process();
265-
delete c;
266-
return true;
267252
}
253+
268254
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
269255
// TODO(Yancey1989): make grpc client completely thread-safe
270-
std::unique_lock<std::mutex> lock(mutex_);
256+
std::lock_guard<std::mutex> guard(chan_mutex_);
271257
auto it = channels_.find(ep);
272258
if (it != channels_.end()) {
273259
return it->second;

paddle/fluid/operators/detail/grpc_client.h

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,18 @@ limitations under the License. */
1616

1717
#include <time.h>
1818

19-
#include <chrono> // NOLINT
19+
#include <chrono> // NOLINT
20+
#include <condition_variable> // NOLINT
2021
#include <ctime>
2122
#include <functional>
2223
#include <iostream>
2324
#include <map>
2425
#include <mutex> // NOLINT
2526
#include <string>
27+
#include <thread> // NOLINT
2628
#include <vector>
2729

30+
#include "grpc++/channel.h"
2831
#include "grpc++/generic/generic_stub.h"
2932
#include "grpc++/grpc++.h"
3033
#include "grpc++/support/byte_buffer.h"
@@ -164,6 +167,7 @@ class FetchBarrierProcessor : public BaseProcessor {
164167
class RPCClient {
165168
public:
166169
RPCClient() {}
170+
~RPCClient();
167171

168172
static RPCClient* GetInstance();
169173

@@ -192,19 +196,28 @@ class RPCClient {
192196
void AsyncSendFetchBarrier(const std::string& ep,
193197
int64_t time_out = 600 * 1000);
194198

195-
bool Wait();
199+
void Wait();
200+
// InitEventLoop should only be called by Init()
201+
void InitEventLoop();
196202

197203
private:
198-
bool Proceed();
204+
void Proceed();
199205
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
200206
// Init is called by GetInstance.
201207
static void Init();
202208

203209
private:
204210
grpc::CompletionQueue cq_;
205-
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
211+
std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
212+
std::unique_ptr<std::thread> client_thread_;
213+
214+
// mutex for Wait client sync
215+
std::mutex sync_mutex_;
216+
std::condition_variable sync_cond_;
206217
std::atomic<int64_t> req_count_{0};
207-
std::mutex mutex_;
218+
219+
// mutex for GetChannel thread safety
220+
std::mutex chan_mutex_;
208221
static std::unique_ptr<RPCClient> rpc_client_;
209222
static std::once_flag init_flag_;
210223
DISABLE_COPY_AND_ASSIGN(RPCClient);

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,7 @@ class RequestSend final : public RequestBase {
6868
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
6969
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
7070
}
71-
7271
virtual ~RequestSend() {}
73-
7472
std::string GetReqName() override { return request_->Varname(); }
7573

7674
void Process() override {
@@ -82,7 +80,6 @@ class RequestSend final : public RequestBase {
8280
framework::Variable* outvar = nullptr;
8381

8482
request_handler_->Handle(varname, scope, invar, &outvar);
85-
8683
status_ = FINISH;
8784
responder_.Finish(reply_, ::grpc::Status::OK,
8885
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
@@ -125,7 +122,6 @@ class RequestGet final : public RequestBase {
125122
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
126123
&reply_);
127124
}
128-
129125
status_ = FINISH;
130126
responder_.Finish(reply_, ::grpc::Status::OK,
131127
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
@@ -170,10 +166,9 @@ class RequestPrefetch final : public RequestBase {
170166

171167
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
172168
&reply_);
173-
174-
status_ = FINISH;
175169
responder_.Finish(reply_, ::grpc::Status::OK,
176170
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
171+
status_ = FINISH;
177172
}
178173

179174
protected:

paddle/fluid/operators/detail/grpc_server_test.cc

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,6 @@ void StartServer() {
113113
std::thread server_thread(
114114
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
115115

116-
// FIXME(gongwb): don't use hard time.
117-
sleep(10);
118-
LOG(INFO) << "got nccl id and stop server...";
119-
g_rpc_service->ShutDown();
120116
server_thread.join();
121117
}
122118

@@ -127,7 +123,7 @@ TEST(PREFETCH, CPU) {
127123
std::thread server_thread(StartServer);
128124
g_rpc_service->WaitServerReady();
129125

130-
detail::RPCClient client;
126+
detail::RPCClient* client = detail::RPCClient::GetInstance();
131127
int port = g_rpc_service->GetSelectedPort();
132128
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
133129

@@ -141,8 +137,8 @@ TEST(PREFETCH, CPU) {
141137
std::string in_var_name("ids");
142138
std::string out_var_name("out");
143139

144-
client.AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name);
145-
client.Wait();
140+
client->AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name);
141+
client->Wait();
146142
auto var = scope.Var(out_var_name);
147143
auto value = var->GetMutable<framework::SelectedRows>()->value();
148144
auto ptr = value.mutable_data<float>(place);
@@ -152,6 +148,7 @@ TEST(PREFETCH, CPU) {
152148
}
153149
}
154150

151+
g_rpc_service->ShutDown();
155152
server_thread.join();
156153
LOG(INFO) << "begin reset";
157154
g_rpc_service.reset(nullptr);

paddle/fluid/operators/fetch_barrier_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase {
4545

4646
auto rpc_client = detail::RPCClient::GetInstance();
4747

48-
PADDLE_ENFORCE(rpc_client->Wait());
48+
rpc_client->Wait();
4949

5050
for (auto& ep : eps) {
5151
VLOG(3) << "fetch barrier, ep: " << ep;
5252
rpc_client->AsyncSendFetchBarrier(ep);
5353
}
54-
PADDLE_ENFORCE(rpc_client->Wait());
54+
rpc_client->Wait();
5555
}
5656
};
5757

paddle/fluid/operators/prefetch_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase {
5353
VLOG(3) << "don't send no-initialied variable: " << ins[i];
5454
}
5555
}
56-
PADDLE_ENFORCE(rpc_client->Wait());
56+
rpc_client->Wait();
5757
}
5858
};
5959

paddle/fluid/operators/recv_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase {
5151
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
5252
}
5353
if (sync_mode) {
54-
PADDLE_ENFORCE(rpc_client->Wait());
54+
rpc_client->Wait();
5555
}
5656
}
5757
};

paddle/fluid/operators/send_barrier_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,13 @@ class SendBarrierOp : public framework::OperatorBase {
4949
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
5050

5151
// need to wait before sending send_barrier message
52-
PADDLE_ENFORCE(rpc_client->Wait());
52+
rpc_client->Wait();
5353
if (sync_mode) {
5454
for (auto& ep : eps) {
5555
VLOG(3) << "send barrier, ep: " << ep;
5656
rpc_client->AsyncSendBatchBarrier(ep);
5757
}
58-
PADDLE_ENFORCE(rpc_client->Wait());
58+
rpc_client->Wait();
5959
}
6060
}
6161
};

paddle/fluid/operators/send_op.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,28 +59,28 @@ class SendOp : public framework::OperatorBase {
5959
VLOG(3) << "don't send no-initialied variable: " << ins[i];
6060
}
6161
}
62-
PADDLE_ENFORCE(rpc_client->Wait());
62+
rpc_client->Wait();
6363

6464
if (sync_mode) {
6565
for (auto& ep : endpoints) {
6666
VLOG(3) << "batch barrier, ep: " << ep;
6767
rpc_client->AsyncSendBatchBarrier(ep);
6868
}
69-
PADDLE_ENFORCE(rpc_client->Wait());
69+
rpc_client->Wait();
7070
}
7171

7272
if (outs.size() > 0) {
7373
for (size_t i = 0; i < outs.size(); i++) {
7474
VLOG(2) << "getting " << outs[i] << " from " << epmap[i];
7575
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
7676
}
77-
PADDLE_ENFORCE(rpc_client->Wait());
77+
rpc_client->Wait();
7878
// tell pservers that current trainer have called fetch
7979
for (auto& ep : endpoints) {
8080
VLOG(2) << "send fetch barrier, ep: " << ep;
8181
rpc_client->AsyncSendFetchBarrier(ep);
8282
}
83-
PADDLE_ENFORCE(rpc_client->Wait());
83+
rpc_client->Wait();
8484
}
8585
}
8686
};

paddle/fluid/operators/test_send_nccl_id.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ void StartServer() {
6161
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
6262

6363
g_rpc_service->SetCond(detail::kRequestSend);
64-
std::cout << "before WaitFanInOfSend" << std::endl;
6564
g_rpc_service->WaitBarrier(detail::kRequestSend);
6665

6766
LOG(INFO) << "got nccl id and stop server...";
@@ -88,12 +87,12 @@ TEST(SendNcclId, GrpcServer) {
8887
int port = g_rpc_service->GetSelectedPort();
8988

9089
std::string ep = string::Sprintf("127.0.0.1:%d", port);
91-
detail::RPCClient client;
92-
LOG(INFO) << "connect to server" << ep;
93-
client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
94-
client.Wait();
95-
client.AsyncSendBatchBarrier(ep);
96-
client.Wait();
90+
detail::RPCClient* client = detail::RPCClient::GetInstance();
91+
LOG(INFO) << "connect to server " << ep;
92+
client->AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
93+
client->Wait();
94+
client->AsyncSendBatchBarrier(ep);
95+
client->Wait();
9796

9897
server_thread.join();
9998
g_rpc_service.reset(nullptr);

0 commit comments

Comments
 (0)