Skip to content

Commit bf0c90f

Browse files
committed
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into fix_async_update_failed
2 parents 86e09b3 + 67ab324 commit bf0c90f

19 files changed

+136
-113
lines changed

cmake/external/grpc.cmake

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ ExternalProject_Add(
4040
# NOTE(wuyi):
4141
# this package is generated by following steps:
4242
# 1. git clone -b v1.8.x https://github.com/grpc/grpc.git
43-
# 2. submodule update --init
43+
# 2. git submodule update --init
4444
# 3. keep only zlib, cares, protobuf, boringssl under "third_party",
4545
# checkout and clean other dirs under third_party
4646
# 4. remove .git, and package the directory.
47-
URL "http://paddlepaddledeps.bj.bcebos.com/grpc-v1.8.x.tar.gz"
48-
URL_MD5 "c9c58ee7d0e8929a63155af6a2ecdbd0"
47+
URL "http://paddlepaddledeps.bj.bcebos.com/grpc-v1.10.x.tar.gz"
48+
URL_MD5 "1f268a2aff6759839dccd256adcc91cf"
4949
PREFIX ${GRPC_SOURCES_DIR}
5050
UPDATE_COMMAND ""
5151
CONFIGURE_COMMAND ""

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
470470
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
471471
const OpDesc &op) const {
472472
int op_dev_id = -1;
473-
if (op.Type() == "split_byref") {
473+
if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") {
474474
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
475475
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
476476
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());

paddle/fluid/operators/assign_value_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ AssignValue operator
7070

7171
namespace ops = paddle::operators;
7272

73-
REGISTER_OPERATOR(assign_value, ops::AssignValueOp, ops::AssignValueOpMaker);
73+
REGISTER_OPERATOR(assign_value, ops::AssignValueOp, ops::AssignValueOpMaker,
74+
paddle::framework::EmptyGradOpMaker);
7475
REGISTER_OP_CPU_KERNEL(assign_value, ops::AssignValueKernel<int>,
7576
ops::AssignValueKernel<float>);

paddle/fluid/operators/distributed/grpc_client.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,14 +269,15 @@ void GRPCClient::Proceed() {
269269
}
270270

271271
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
272-
// TODO(Yancey1989): make grpc client completely thread-safe
273272
std::lock_guard<std::mutex> guard(chan_mutex_);
274273
auto it = channels_.find(ep);
275274
if (it != channels_.end()) {
276275
return it->second;
277276
}
278277

278+
// Channel configurations:
279279
grpc::ChannelArguments args;
280+
args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
280281
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
281282
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
282283
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

paddle/fluid/operators/distributed/grpc_client.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class BaseProcessor {
7676
virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
7777
context_.reset(new grpc::ClientContext());
7878
var_h_ = var_info;
79+
context_->set_wait_for_ready(true);
7980

8081
std::chrono::system_clock::time_point deadline =
8182
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
@@ -85,6 +86,7 @@ class BaseProcessor {
8586

8687
virtual void Prepare(int64_t time_out) {
8788
context_.reset(new grpc::ClientContext());
89+
context_->set_wait_for_ready(true);
8890

8991
std::chrono::system_clock::time_point deadline =
9092
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
@@ -176,26 +178,24 @@ class GRPCClient : public RPCClient {
176178

177179
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
178180
const framework::Scope& scope, const std::string& var_name,
179-
int64_t time_out = RPCClient::rpc_time_out) override;
181+
int64_t time_out = FLAGS_grpc_deadline) override;
180182

181183
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
182184
const framework::Scope& scope, const std::string& var_name,
183-
int64_t time_out = RPCClient::rpc_time_out) override;
185+
int64_t time_out = FLAGS_grpc_deadline) override;
184186

185187
bool AsyncPrefetchVar(const std::string& ep,
186188
const platform::DeviceContext& ctx,
187189
const framework::Scope& scope,
188190
const std::string& in_var_name,
189191
const std::string& out_var_name,
190-
int64_t time_out = RPCClient::rpc_time_out) override;
192+
int64_t time_out = FLAGS_grpc_deadline) override;
191193

192-
void AsyncSendBatchBarrier(
193-
const std::string& ep,
194-
int64_t time_out = RPCClient::rpc_time_out) override;
194+
void AsyncSendBatchBarrier(const std::string& ep,
195+
int64_t time_out = FLAGS_grpc_deadline) override;
195196

196-
void AsyncSendFetchBarrier(
197-
const std::string& ep,
198-
int64_t time_out = RPCClient::rpc_time_out) override;
197+
void AsyncSendFetchBarrier(const std::string& ep,
198+
int64_t time_out = FLAGS_grpc_deadline) override;
199199

200200
void Wait() override;
201201

@@ -211,7 +211,7 @@ class GRPCClient : public RPCClient {
211211
void Proceed();
212212

213213
void AsyncSendComplete(const std::string& ep,
214-
int64_t time_out = RPCClient::rpc_time_out);
214+
int64_t time_out = FLAGS_grpc_deadline);
215215

216216
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
217217

paddle/fluid/operators/distributed/grpc_server.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class RequestSend final : public RequestBase {
9797

9898
void Process() override {
9999
std::string varname = GetReqName();
100-
VLOG(3) << "RequestSend var_name:" << varname;
100+
VLOG(4) << "RequestSend var_name:" << varname;
101101

102102
auto scope = request_->GetMutableLocalScope();
103103
auto invar = request_->GetVar();
@@ -132,7 +132,7 @@ class RequestGet final : public RequestBase {
132132
void Process() override {
133133
// proc request.
134134
std::string varname = request_.varname();
135-
VLOG(3) << "RequestGet " << varname;
135+
VLOG(4) << "RequestGet " << varname;
136136

137137
auto scope = request_handler_->scope();
138138
auto invar = scope->FindVar(varname);
@@ -178,7 +178,7 @@ class RequestPrefetch final : public RequestBase {
178178
// prefetch process...
179179
std::string in_var_name = request_->Varname();
180180
std::string out_var_name = request_->OutVarname();
181-
VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
181+
VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
182182
<< " out_var_name: " << out_var_name;
183183

184184
auto scope = request_->GetMutableLocalScope();
@@ -201,10 +201,10 @@ class RequestPrefetch final : public RequestBase {
201201
};
202202

203203
void AsyncGRPCServer::WaitServerReady() {
204-
VLOG(3) << "AsyncGRPCServer is wait server ready";
204+
VLOG(4) << "AsyncGRPCServer is wait server ready";
205205
std::unique_lock<std::mutex> lock(this->mutex_ready_);
206206
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
207-
VLOG(3) << "AsyncGRPCServer WaitSeverReady";
207+
VLOG(4) << "AsyncGRPCServer WaitSeverReady";
208208
}
209209

210210
void AsyncGRPCServer::StartServer() {
@@ -243,7 +243,7 @@ void AsyncGRPCServer::StartServer() {
243243
for (int i = 0; i < threadnum; i++) {
244244
rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind(
245245
&AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f)));
246-
VLOG(3) << t.first << " creates threads!";
246+
VLOG(4) << t.first << " creates threads!";
247247
}
248248
}
249249

@@ -260,15 +260,15 @@ void AsyncGRPCServer::StartServer() {
260260
auto& threads = t.second;
261261
for (size_t i = 0; i < threads.size(); ++i) {
262262
threads[i]->join();
263-
VLOG(3) << t.first << " threads ends!";
263+
VLOG(4) << t.first << " threads ends!";
264264
}
265265
}
266266
}
267267

268268
void AsyncGRPCServer::ShutdownQueue() {
269269
for (auto& t : rpc_cq_) {
270270
t.second->Shutdown();
271-
VLOG(3) << t.first << " shutdown!";
271+
VLOG(4) << t.first << " queue shutdown!";
272272
}
273273
}
274274

@@ -277,15 +277,15 @@ void AsyncGRPCServer::ShutDownImpl() {
277277
is_shut_down_ = true;
278278
ShutdownQueue();
279279

280-
VLOG(3) << "server_ shutdown!";
280+
VLOG(4) << "server_ shutdown!";
281281
server_->Shutdown();
282282
}
283283

284284
void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
285285
int req_id) {
286286
std::unique_lock<std::mutex> lock(cq_mutex_);
287287
if (is_shut_down_) {
288-
LOG(WARNING) << "shutdown, do not TryToRegisterNewSendOne";
288+
VLOG(4) << "shutdown, do not TryToRegisterNewSendOne";
289289
return;
290290
}
291291

paddle/fluid/operators/distributed/rpc_client.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/operators/distributed/rpc_client.h"
16+
#include "gflags/gflags.h"
17+
18+
// default to 3min to avoid temporary network failures.
19+
DEFINE_int32(grpc_deadline, 180000, "deadline timeouts for grpc");
1620

1721
namespace paddle {
1822
namespace operators {

paddle/fluid/operators/distributed/rpc_client.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
#pragma once
1616

1717
#include <string>
18+
#include "gflags/gflags.h"
1819

1920
#include "paddle/fluid/framework/data_type.h"
2021
#include "paddle/fluid/framework/lod_tensor.h"
2122
#include "paddle/fluid/framework/scope.h"
2223

24+
DECLARE_int32(grpc_deadline);
25+
2326
namespace paddle {
2427
namespace operators {
2528
namespace distributed {
@@ -32,26 +35,26 @@ class RPCClient {
3235
const platform::DeviceContext& ctx,
3336
const framework::Scope& scope,
3437
const std::string& var_name,
35-
int64_t time_out = rpc_time_out) = 0;
38+
int64_t time_out = FLAGS_grpc_deadline) = 0;
3639

3740
virtual bool AsyncGetVar(const std::string& ep,
3841
const platform::DeviceContext& ctx,
3942
const framework::Scope& scope,
4043
const std::string& var_name,
41-
int64_t time_out = rpc_time_out) = 0;
44+
int64_t time_out = FLAGS_grpc_deadline) = 0;
4245

4346
virtual bool AsyncPrefetchVar(const std::string& ep,
4447
const platform::DeviceContext& ctx,
4548
const framework::Scope& scope,
4649
const std::string& in_var_name,
4750
const std::string& out_var_name,
48-
int64_t time_out = rpc_time_out) = 0;
51+
int64_t time_out = FLAGS_grpc_deadline) = 0;
4952

50-
virtual void AsyncSendBatchBarrier(const std::string& ep,
51-
int64_t time_out = rpc_time_out) = 0;
53+
virtual void AsyncSendBatchBarrier(
54+
const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0;
5255

53-
virtual void AsyncSendFetchBarrier(const std::string& ep,
54-
int64_t time_out = rpc_time_out) = 0;
56+
virtual void AsyncSendFetchBarrier(
57+
const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0;
5558

5659
// SendComplete tells all the server that current trainer have no more data
5760
// to train, so that the pserver can reduce it's barrier count, and continue
@@ -60,8 +63,6 @@ class RPCClient {
6063

6164
virtual void Wait() = 0;
6265

63-
static constexpr int64_t rpc_time_out = 120 * 1000;
64-
6566
template <typename T>
6667
static RPCClient* GetInstance() {
6768
std::call_once(init_flag_, &RPCClient::Init<T>);

paddle/fluid/operators/distributed/rpc_server.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,12 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) {
4747
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
4848
});
4949

50-
VLOG(3) << "batch_barrier_:" << barrier_counter_[rpc_name];
50+
VLOG(3) << "batch_barrier_: " << rpc_name << " "
51+
<< barrier_counter_[rpc_name];
5152
}
5253

5354
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
54-
VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
55+
VLOG(4) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
5556
int b = 0;
5657
std::unique_lock<std::mutex> lock(mutex_);
5758
b = ++barrier_counter_[rpc_name];
@@ -100,7 +101,7 @@ void RPCServer::SetCond(const std::string& rpc_name) {
100101
}
101102

102103
void RPCServer::WaitCond(const std::string& rpc_name) {
103-
VLOG(3) << "RPCServer WaitCond " << rpc_name;
104+
VLOG(4) << "RPCServer WaitCond " << rpc_name;
104105
int cond = 0;
105106
{
106107
std::unique_lock<std::mutex> lock(mutex_);

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@ void ListenAndServOp::RunSyncLoop(
165165
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
166166
framework::ProgramDesc *program,
167167
framework::Scope *recv_scope) const {
168-
VLOG(3) << "RunAsyncLoop in";
169168
// grad name to block id
170169
std::unordered_map<std::string, int32_t> grad_to_block_id;
171170
std::unordered_map<int32_t, std::string> id_to_grad;
@@ -207,7 +206,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
207206
request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
208207
request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
209208

210-
VLOG(3) << "RunAsyncLoop into while";
211209
while (true) {
212210
if (rpc_service_->IsExit()) {
213211
LOG(INFO) << "get exit!rpc_processor break!";

0 commit comments

Comments
 (0)