Skip to content

Commit 2028a8e

Browse files
authored
Add rpc_client interface. (#11154)
1 parent ca2d6d3 commit 2028a8e

14 files changed

+191
-89
lines changed

paddle/fluid/operators/detail/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
if(WITH_DISTRIBUTE)
22
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
3-
request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
3+
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
44
selected_rows memory)
55
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
66
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 25 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -25,29 +25,15 @@ namespace paddle {
2525
namespace operators {
2626
namespace detail {
2727

28-
std::once_flag RPCClient::init_flag_;
28+
void GRPCClient::InitImpl() { InitEventLoop(); }
2929

30-
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
31-
32-
RPCClient* RPCClient::GetInstance() {
33-
std::call_once(init_flag_, &RPCClient::Init);
34-
return rpc_client_.get();
35-
}
36-
37-
void RPCClient::Init() {
38-
if (rpc_client_.get() == nullptr) {
39-
rpc_client_.reset(new RPCClient());
40-
}
41-
rpc_client_->InitEventLoop();
42-
}
43-
44-
void RPCClient::InitEventLoop() {
30+
void GRPCClient::InitEventLoop() {
4531
// start the client process thread
4632
// TODO(wuyi): can make this in a threadpool
47-
client_thread_.reset(new std::thread(std::bind(&RPCClient::Proceed, this)));
33+
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
4834
}
4935

50-
RPCClient::~RPCClient() {
36+
GRPCClient::~GRPCClient() {
5137
Wait();
5238
cq_.Shutdown();
5339
{
@@ -59,11 +45,10 @@ RPCClient::~RPCClient() {
5945
client_thread_->join();
6046
}
6147

62-
bool RPCClient::AsyncSendVariable(const std::string& ep,
63-
const platform::DeviceContext& ctx,
64-
const framework::Scope& scope,
65-
const std::string& var_name,
66-
int64_t time_out) {
48+
bool GRPCClient::AsyncSendVar(const std::string& ep,
49+
const platform::DeviceContext& ctx,
50+
const framework::Scope& scope,
51+
const std::string& var_name, int64_t time_out) {
6752
const platform::DeviceContext* p_ctx = &ctx;
6853
const std::string ep_val = ep;
6954
const std::string var_name_val = var_name;
@@ -113,11 +98,10 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
11398
result->Swap(&tmp);
11499
}
115100

116-
bool RPCClient::AsyncGetVariable(const std::string& ep,
117-
const platform::DeviceContext& ctx,
118-
const framework::Scope& scope,
119-
const std::string& var_name,
120-
int64_t time_out) {
101+
bool GRPCClient::AsyncGetVar(const std::string& ep,
102+
const platform::DeviceContext& ctx,
103+
const framework::Scope& scope,
104+
const std::string& var_name, int64_t time_out) {
121105
const platform::DeviceContext* p_ctx = &ctx;
122106
const std::string ep_val = ep;
123107
const std::string var_name_val = var_name;
@@ -155,12 +139,12 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
155139
return true;
156140
}
157141

158-
bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
159-
const platform::DeviceContext& ctx,
160-
const framework::Scope& scope,
161-
const std::string& in_var_name,
162-
const std::string& out_var_name,
163-
int64_t time_out) {
142+
bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
143+
const platform::DeviceContext& ctx,
144+
const framework::Scope& scope,
145+
const std::string& in_var_name,
146+
const std::string& out_var_name,
147+
int64_t time_out) {
164148
const platform::DeviceContext* p_ctx = &ctx;
165149
const std::string ep_val = ep;
166150
const std::string in_var_name_val = in_var_name;
@@ -198,7 +182,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
198182
return true;
199183
}
200184

201-
void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
185+
void GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
186+
int64_t time_out) {
202187
const auto ch = GetChannel(ep);
203188

204189
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
@@ -211,7 +196,8 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
211196
req_count_++;
212197
}
213198

214-
void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
199+
void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
200+
int64_t time_out) {
215201
const auto ch = GetChannel(ep);
216202
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
217203
s->Prepare(time_out);
@@ -223,12 +209,12 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
223209
req_count_++;
224210
}
225211

226-
void RPCClient::Wait() {
212+
void GRPCClient::Wait() {
227213
std::unique_lock<std::mutex> lk(sync_mutex_);
228214
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
229215
}
230216

231-
void RPCClient::Proceed() {
217+
void GRPCClient::Proceed() {
232218
void* tag = nullptr;
233219
bool ok = false;
234220

@@ -251,7 +237,7 @@ void RPCClient::Proceed() {
251237
}
252238
}
253239

254-
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
240+
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
255241
// TODO(Yancey1989): make grpc client completely thread-safe
256242
std::lock_guard<std::mutex> guard(chan_mutex_);
257243
auto it = channels_.find(ep);

paddle/fluid/operators/detail/grpc_client.h

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ limitations under the License. */
3838
#include "paddle/fluid/framework/lod_tensor.h"
3939
#include "paddle/fluid/framework/scope.h"
4040
#include "paddle/fluid/framework/selected_rows.h"
41+
#include "paddle/fluid/operators/detail/rpc_client.h"
4142
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
4243
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
4344

@@ -164,47 +165,46 @@ class FetchBarrierProcessor : public BaseProcessor {
164165
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
165166
};
166167

167-
class RPCClient {
168+
class GRPCClient : public RPCClient {
168169
public:
169-
RPCClient() {}
170-
~RPCClient();
170+
GRPCClient() {}
171+
virtual ~GRPCClient();
171172

172-
static RPCClient* GetInstance();
173+
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
174+
const framework::Scope& scope, const std::string& var_name,
175+
int64_t time_out = RPCClient::rpc_time_out) override;
173176

174-
bool AsyncSendVariable(const std::string& ep,
175-
const platform::DeviceContext& ctx,
176-
const framework::Scope& scope,
177-
const std::string& var_name,
178-
int64_t time_out = 600 * 1000);
177+
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
178+
const framework::Scope& scope, const std::string& var_name,
179+
int64_t time_out = RPCClient::rpc_time_out) override;
179180

180-
bool AsyncGetVariable(const std::string& ep,
181+
bool AsyncPrefetchVar(const std::string& ep,
181182
const platform::DeviceContext& ctx,
182183
const framework::Scope& scope,
183-
const std::string& var_name,
184-
int64_t time_out = 600 * 1000);
184+
const std::string& in_var_name,
185+
const std::string& out_var_name,
186+
int64_t time_out = RPCClient::rpc_time_out) override;
185187

186-
bool AsyncPrefetchVariable(const std::string& ep,
187-
const platform::DeviceContext& ctx,
188-
const framework::Scope& scope,
189-
const std::string& in_var_name,
190-
const std::string& out_var_name,
191-
int64_t time_out = 600 * 1000);
188+
void AsyncSendBatchBarrier(
189+
const std::string& ep,
190+
int64_t time_out = RPCClient::rpc_time_out) override;
192191

193-
void AsyncSendBatchBarrier(const std::string& ep,
194-
int64_t time_out = 600 * 1000);
192+
void AsyncSendFetchBarrier(
193+
const std::string& ep,
194+
int64_t time_out = RPCClient::rpc_time_out) override;
195195

196-
void AsyncSendFetchBarrier(const std::string& ep,
197-
int64_t time_out = 600 * 1000);
196+
void Wait() override;
198197

199-
void Wait();
198+
protected:
199+
void InitImpl() override;
200+
201+
private:
200202
// InitEventLoop should only be called by Init()
201203
void InitEventLoop();
202204

203-
private:
204205
void Proceed();
206+
205207
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
206-
// Init is called by GetInstance.
207-
static void Init();
208208

209209
private:
210210
grpc::CompletionQueue cq_;
@@ -218,9 +218,7 @@ class RPCClient {
218218

219219
// mutex for GetChannel thread safety
220220
std::mutex chan_mutex_;
221-
static std::unique_ptr<RPCClient> rpc_client_;
222-
static std::once_flag init_flag_;
223-
DISABLE_COPY_AND_ASSIGN(RPCClient);
221+
DISABLE_COPY_AND_ASSIGN(GRPCClient);
224222
};
225223

226224
} // namespace detail

paddle/fluid/operators/detail/grpc_server_test.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License. */
1919
#include "gtest/gtest.h"
2020
#include "paddle/fluid/operators/detail/grpc_client.h"
2121
#include "paddle/fluid/operators/detail/grpc_server.h"
22+
#include "paddle/fluid/operators/detail/rpc_client.h"
2223

2324
#include "paddle/fluid/framework/block_desc.h"
2425
#include "paddle/fluid/framework/op_registry.h"
@@ -123,7 +124,8 @@ TEST(PREFETCH, CPU) {
123124
std::thread server_thread(StartServer);
124125
g_rpc_service->WaitServerReady();
125126

126-
detail::RPCClient* client = detail::RPCClient::GetInstance();
127+
detail::RPCClient* client =
128+
detail::RPCClient::GetInstance<detail::GRPCClient>();
127129
int port = g_rpc_service->GetSelectedPort();
128130
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
129131

@@ -137,7 +139,7 @@ TEST(PREFETCH, CPU) {
137139
std::string in_var_name("ids");
138140
std::string out_var_name("out");
139141

140-
client->AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name);
142+
client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name);
141143
client->Wait();
142144
auto var = scope.Var(out_var_name);
143145
auto value = var->GetMutable<framework::SelectedRows>()->value();
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/detail/rpc_client.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
namespace detail {
20+
21+
std::once_flag RPCClient::init_flag_;
22+
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
23+
24+
} // namespace detail
25+
} // namespace operators
26+
} // namespace paddle
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <string>
18+
19+
#include "paddle/fluid/framework/data_type.h"
20+
#include "paddle/fluid/framework/lod_tensor.h"
21+
#include "paddle/fluid/framework/scope.h"
22+
23+
namespace paddle {
24+
namespace operators {
25+
namespace detail {
26+
27+
class RPCClient {
28+
public:
29+
virtual bool AsyncSendVar(const std::string& ep,
30+
const platform::DeviceContext& ctx,
31+
const framework::Scope& scope,
32+
const std::string& var_name,
33+
int64_t time_out = rpc_time_out) = 0;
34+
35+
virtual bool AsyncGetVar(const std::string& ep,
36+
const platform::DeviceContext& ctx,
37+
const framework::Scope& scope,
38+
const std::string& var_name,
39+
int64_t time_out = rpc_time_out) = 0;
40+
41+
virtual bool AsyncPrefetchVar(const std::string& ep,
42+
const platform::DeviceContext& ctx,
43+
const framework::Scope& scope,
44+
const std::string& in_var_name,
45+
const std::string& out_var_name,
46+
int64_t time_out = rpc_time_out) = 0;
47+
48+
virtual void AsyncSendBatchBarrier(const std::string& ep,
49+
int64_t time_out = rpc_time_out) = 0;
50+
51+
virtual void AsyncSendFetchBarrier(const std::string& ep,
52+
int64_t time_out = rpc_time_out) = 0;
53+
54+
virtual void Wait() = 0;
55+
56+
static constexpr int64_t rpc_time_out = 120 * 1000;
57+
58+
template <typename T>
59+
static RPCClient* GetInstance() {
60+
std::call_once(init_flag_, &RPCClient::Init<T>);
61+
return rpc_client_.get();
62+
}
63+
64+
// Init is called by GetInstance.
65+
template <typename T>
66+
static void Init() {
67+
if (rpc_client_.get() == nullptr) {
68+
rpc_client_.reset(new T());
69+
rpc_client_->InitImpl();
70+
}
71+
}
72+
73+
protected:
74+
virtual void InitImpl() {}
75+
76+
private:
77+
static std::once_flag init_flag_;
78+
static std::unique_ptr<RPCClient> rpc_client_;
79+
};
80+
} // namespace detail
81+
} // namespace operators
82+
} // namespace paddle

paddle/fluid/operators/fetch_barrier_op.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License. */
2121
#include "paddle/fluid/framework/op_registry.h"
2222

2323
#include "paddle/fluid/operators/detail/grpc_client.h"
24+
#include "paddle/fluid/operators/detail/rpc_client.h"
2425
#include "paddle/fluid/platform/profiler.h"
2526

2627
namespace paddle {
@@ -43,7 +44,8 @@ class FetchBarrierOp : public framework::OperatorBase {
4344
// For profiling
4445
platform::RecordEvent record_event(Type(), &ctx);
4546

46-
auto rpc_client = detail::RPCClient::GetInstance();
47+
detail::RPCClient* rpc_client =
48+
detail::RPCClient::GetInstance<detail::GRPCClient>();
4749

4850
rpc_client->Wait();
4951

0 commit comments

Comments
 (0)