Skip to content

Commit d2d6e8f

Browse files
authored
cherrypick grpc fixes (#11692)
1 parent 5778040 commit d2d6e8f

File tree

8 files changed

+43
-38
lines changed

8 files changed

+43
-38
lines changed

cmake/external/grpc.cmake

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ ExternalProject_Add(
4040
# NOTE(wuyi):
4141
# this package is generated by following steps:
4242
# 1. git clone -b v1.8.x https://github.com/grpc/grpc.git
43-
# 2. submodule update --init
43+
# 2. git submodule update --init
4444
# 3. keep only zlib, cares, protobuf, boringssl under "third_party",
4545
# checkout and clean other dirs under third_party
4646
# 4. remove .git, and package the directory.
47-
URL "http://paddlepaddledeps.bj.bcebos.com/grpc-v1.8.x.tar.gz"
48-
URL_MD5 "c9c58ee7d0e8929a63155af6a2ecdbd0"
47+
URL "http://paddlepaddledeps.bj.bcebos.com/grpc-v1.10.x.tar.gz"
48+
URL_MD5 "1f268a2aff6759839dccd256adcc91cf"
4949
PREFIX ${GRPC_SOURCES_DIR}
5050
UPDATE_COMMAND ""
5151
CONFIGURE_COMMAND ""

paddle/fluid/operators/distributed/grpc_client.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,14 +258,15 @@ void GRPCClient::Proceed() {
258258
}
259259

260260
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
261-
// TODO(Yancey1989): make grpc client completely thread-safe
262261
std::lock_guard<std::mutex> guard(chan_mutex_);
263262
auto it = channels_.find(ep);
264263
if (it != channels_.end()) {
265264
return it->second;
266265
}
267266

267+
// Channel configurations:
268268
grpc::ChannelArguments args;
269+
args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
269270
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
270271
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
271272
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

paddle/fluid/operators/distributed/grpc_client.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class BaseProcessor {
7272
virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
7373
context_.reset(new grpc::ClientContext());
7474
var_h_ = var_info;
75+
context_->set_wait_for_ready(true);
7576

7677
std::chrono::system_clock::time_point deadline =
7778
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
@@ -81,6 +82,7 @@ class BaseProcessor {
8182

8283
virtual void Prepare(int64_t time_out) {
8384
context_.reset(new grpc::ClientContext());
85+
context_->set_wait_for_ready(true);
8486

8587
std::chrono::system_clock::time_point deadline =
8688
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
@@ -172,26 +174,24 @@ class GRPCClient : public RPCClient {
172174

173175
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
174176
const framework::Scope& scope, const std::string& var_name,
175-
int64_t time_out = RPCClient::rpc_time_out) override;
177+
int64_t time_out = FLAGS_grpc_deadline) override;
176178

177179
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
178180
const framework::Scope& scope, const std::string& var_name,
179-
int64_t time_out = RPCClient::rpc_time_out) override;
181+
int64_t time_out = FLAGS_grpc_deadline) override;
180182

181183
bool AsyncPrefetchVar(const std::string& ep,
182184
const platform::DeviceContext& ctx,
183185
const framework::Scope& scope,
184186
const std::string& in_var_name,
185187
const std::string& out_var_name,
186-
int64_t time_out = RPCClient::rpc_time_out) override;
188+
int64_t time_out = FLAGS_grpc_deadline) override;
187189

188-
void AsyncSendBatchBarrier(
189-
const std::string& ep,
190-
int64_t time_out = RPCClient::rpc_time_out) override;
190+
void AsyncSendBatchBarrier(const std::string& ep,
191+
int64_t time_out = FLAGS_grpc_deadline) override;
191192

192-
void AsyncSendFetchBarrier(
193-
const std::string& ep,
194-
int64_t time_out = RPCClient::rpc_time_out) override;
193+
void AsyncSendFetchBarrier(const std::string& ep,
194+
int64_t time_out = FLAGS_grpc_deadline) override;
195195

196196
void Wait() override;
197197

@@ -207,7 +207,7 @@ class GRPCClient : public RPCClient {
207207
void Proceed();
208208

209209
void AsyncSendComplete(const std::string& ep,
210-
int64_t time_out = RPCClient::rpc_time_out);
210+
int64_t time_out = FLAGS_grpc_deadline);
211211

212212
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
213213

paddle/fluid/operators/distributed/grpc_server.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class RequestSend final : public RequestBase {
8484

8585
void Process() override {
8686
std::string varname = GetReqName();
87-
VLOG(3) << "RequestSend var_name:" << varname;
87+
VLOG(4) << "RequestSend var_name:" << varname;
8888

8989
auto scope = request_->GetMutableLocalScope();
9090
auto invar = request_->GetVar();
@@ -119,7 +119,7 @@ class RequestGet final : public RequestBase {
119119
void Process() override {
120120
// proc request.
121121
std::string varname = request_.varname();
122-
VLOG(3) << "RequestGet " << varname;
122+
VLOG(4) << "RequestGet " << varname;
123123

124124
auto scope = request_handler_->scope();
125125
auto invar = scope->FindVar(varname);
@@ -165,7 +165,7 @@ class RequestPrefetch final : public RequestBase {
165165
// prefetch process...
166166
std::string in_var_name = request_->Varname();
167167
std::string out_var_name = request_->OutVarname();
168-
VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
168+
VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
169169
<< " out_var_name: " << out_var_name;
170170

171171
auto scope = request_->GetMutableLocalScope();
@@ -188,10 +188,10 @@ class RequestPrefetch final : public RequestBase {
188188
};
189189

190190
void AsyncGRPCServer::WaitServerReady() {
191-
VLOG(3) << "AsyncGRPCServer is wait server ready";
191+
VLOG(4) << "AsyncGRPCServer is wait server ready";
192192
std::unique_lock<std::mutex> lock(this->mutex_ready_);
193193
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
194-
VLOG(3) << "AsyncGRPCServer WaitSeverReady";
194+
VLOG(4) << "AsyncGRPCServer WaitSeverReady";
195195
}
196196

197197
void AsyncGRPCServer::StartServer() {
@@ -230,7 +230,7 @@ void AsyncGRPCServer::StartServer() {
230230
for (int i = 0; i < threadnum; i++) {
231231
rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind(
232232
&AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f)));
233-
VLOG(3) << t.first << " creates threads!";
233+
VLOG(4) << t.first << " creates threads!";
234234
}
235235
}
236236

@@ -247,15 +247,15 @@ void AsyncGRPCServer::StartServer() {
247247
auto& threads = t.second;
248248
for (size_t i = 0; i < threads.size(); ++i) {
249249
threads[i]->join();
250-
VLOG(3) << t.first << " threads ends!";
250+
VLOG(4) << t.first << " threads ends!";
251251
}
252252
}
253253
}
254254

255255
void AsyncGRPCServer::ShutdownQueue() {
256256
for (auto& t : rpc_cq_) {
257257
t.second->Shutdown();
258-
VLOG(3) << t.first << " shutdown!";
258+
VLOG(4) << t.first << " queue shutdown!";
259259
}
260260
}
261261

@@ -264,15 +264,15 @@ void AsyncGRPCServer::ShutDownImpl() {
264264
is_shut_down_ = true;
265265
ShutdownQueue();
266266

267-
VLOG(3) << "server_ shutdown!";
267+
VLOG(4) << "server_ shutdown!";
268268
server_->Shutdown();
269269
}
270270

271271
void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
272272
int req_id) {
273273
std::unique_lock<std::mutex> lock(cq_mutex_);
274274
if (is_shut_down_) {
275-
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
275+
VLOG(4) << "shutdown, do not TryToRegisterNewSendOne";
276276
return;
277277
}
278278

paddle/fluid/operators/distributed/rpc_client.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/operators/distributed/rpc_client.h"
16+
#include "gflags/gflags.h"
17+
18+
// default to 3min to avoid temprary network failures.
19+
DEFINE_int32(grpc_deadline, 180000, "deadline timeouts for grpc");
1620

1721
namespace paddle {
1822
namespace operators {

paddle/fluid/operators/distributed/rpc_client.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
#pragma once
1616

1717
#include <string>
18+
#include "gflags/gflags.h"
1819

1920
#include "paddle/fluid/framework/data_type.h"
2021
#include "paddle/fluid/framework/lod_tensor.h"
2122
#include "paddle/fluid/framework/scope.h"
2223

24+
DECLARE_int32(grpc_deadline);
25+
2326
namespace paddle {
2427
namespace operators {
2528
namespace distributed {
@@ -32,26 +35,26 @@ class RPCClient {
3235
const platform::DeviceContext& ctx,
3336
const framework::Scope& scope,
3437
const std::string& var_name,
35-
int64_t time_out = rpc_time_out) = 0;
38+
int64_t time_out = FLAGS_grpc_deadline) = 0;
3639

3740
virtual bool AsyncGetVar(const std::string& ep,
3841
const platform::DeviceContext& ctx,
3942
const framework::Scope& scope,
4043
const std::string& var_name,
41-
int64_t time_out = rpc_time_out) = 0;
44+
int64_t time_out = FLAGS_grpc_deadline) = 0;
4245

4346
virtual bool AsyncPrefetchVar(const std::string& ep,
4447
const platform::DeviceContext& ctx,
4548
const framework::Scope& scope,
4649
const std::string& in_var_name,
4750
const std::string& out_var_name,
48-
int64_t time_out = rpc_time_out) = 0;
51+
int64_t time_out = FLAGS_grpc_deadline) = 0;
4952

50-
virtual void AsyncSendBatchBarrier(const std::string& ep,
51-
int64_t time_out = rpc_time_out) = 0;
53+
virtual void AsyncSendBatchBarrier(
54+
const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0;
5255

53-
virtual void AsyncSendFetchBarrier(const std::string& ep,
54-
int64_t time_out = rpc_time_out) = 0;
56+
virtual void AsyncSendFetchBarrier(
57+
const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0;
5558

5659
// SendComplete tells all the server that current trainer have no more data
5760
// to train, so that the pserver can reduce it's barrier count, and continue
@@ -60,8 +63,6 @@ class RPCClient {
6063

6164
virtual void Wait() = 0;
6265

63-
static constexpr int64_t rpc_time_out = 120 * 1000;
64-
6566
template <typename T>
6667
static RPCClient* GetInstance() {
6768
std::call_once(init_flag_, &RPCClient::Init<T>);

paddle/fluid/operators/distributed/rpc_server.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,12 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) {
4747
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
4848
});
4949

50-
VLOG(3) << "batch_barrier_:" << barrier_counter_[rpc_name];
50+
VLOG(3) << "batch_barrier_: " << rpc_name << " "
51+
<< barrier_counter_[rpc_name];
5152
}
5253

5354
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
54-
VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
55+
VLOG(4) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
5556
int b = 0;
5657
std::unique_lock<std::mutex> lock(mutex_);
5758
b = ++barrier_counter_[rpc_name];
@@ -100,7 +101,7 @@ void RPCServer::SetCond(const std::string& rpc_name) {
100101
}
101102

102103
void RPCServer::WaitCond(const std::string& rpc_name) {
103-
VLOG(3) << "RPCServer WaitCond " << rpc_name;
104+
VLOG(4) << "RPCServer WaitCond " << rpc_name;
104105
int cond = 0;
105106
{
106107
std::unique_lock<std::mutex> lock(mutex_);

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@ void ListenAndServOp::RunSyncLoop(
165165

166166
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
167167
framework::ProgramDesc *program) const {
168-
VLOG(3) << "RunAsyncLoop in";
169168
// grad name to block id
170169
std::unordered_map<std::string, int32_t> grad_to_block_id;
171170
std::unordered_map<int32_t, std::string> id_to_grad;
@@ -203,7 +202,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
203202
request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
204203
request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
205204

206-
VLOG(3) << "RunAsyncLoop into while";
207205
while (true) {
208206
if (rpc_service_->IsExit()) {
209207
LOG(INFO) << "get exit!rpc_processor break!";

0 commit comments

Comments
 (0)