Skip to content

Commit 0042ba9

Browse files
author
Yancey
authored
Merge pull request #12127 from Yancey1989/enforce_rpc_timeout
Enforce rpc timeout
2 parents 325fbc4 + d14afce commit 0042ba9

File tree

12 files changed

+28
-18
lines changed

12 files changed

+28
-18
lines changed

paddle/fluid/framework/parallel_executor.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ ParallelExecutor::ParallelExecutor(
104104
}
105105

106106
if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
107-
BCastParamsToDevs(bcast_vars);
107+
BCastParamsToDevices(bcast_vars);
108108
}
109109
// Startup Program has been run. All local scopes have correct parameters.
110110

@@ -140,7 +140,7 @@ ParallelExecutor::ParallelExecutor(
140140
member_->places_, std::move(member_->executor_)));
141141
}
142142

143-
void ParallelExecutor::BCastParamsToDevs(
143+
void ParallelExecutor::BCastParamsToDevices(
144144
const std::unordered_set<std::string> &vars) const {
145145
// the initializing bcast, all vars would be bcast from device(0),
146146
// otherwise

paddle/fluid/framework/parallel_executor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class ParallelExecutor {
6666
void Run(const std::vector<std::string> &fetch_tensors,
6767
const std::string &fetched_var_name);
6868

69-
void BCastParamsToDevs(const std::unordered_set<std::string> &vars) const;
69+
void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
7070

7171
private:
7272
ParallelExecutorPrivate *member_;

paddle/fluid/operators/checkpoint_notify_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class CheckpointNotifyOp : public framework::OperatorBase {
4848
VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name
4949
<< " and dir:" << dir << " to " << epmap[i];
5050
}
51-
rpc_client->Wait();
51+
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
5252
}
5353
};
5454

paddle/fluid/operators/distributed/grpc_client.cc

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,10 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
281281
req_count_++;
282282
}
283283

284-
void GRPCClient::Wait() {
284+
bool GRPCClient::Wait() {
285285
std::unique_lock<std::mutex> lk(sync_mutex_);
286-
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
286+
sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
287+
return ok_;
287288
}
288289

289290
void GRPCClient::Proceed() {
@@ -297,6 +298,14 @@ void GRPCClient::Proceed() {
297298
if (c->status_.ok()) {
298299
VLOG(3) << c->var_h_.String() << " process";
299300
c->Process();
301+
} else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
302+
LOG(ERROR) << c->var_h_.String()
303+
<< " meets grpc error:" << c->status_.error_message();
304+
{
305+
std::lock_guard<std::mutex> lk(sync_mutex_);
306+
ok_ = false;
307+
}
308+
sync_cond_.notify_all();
300309
} else {
301310
LOG(FATAL) << c->var_h_.String()
302311
<< " meets grpc error:" << c->status_.error_message();

paddle/fluid/operators/distributed/grpc_client.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ class CheckpointNotifyProcessor : public BaseProcessor {
188188

189189
class GRPCClient : public RPCClient {
190190
public:
191-
GRPCClient() {}
191+
GRPCClient() : ok_(true) {}
192192
virtual ~GRPCClient();
193193

194194
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
@@ -221,7 +221,7 @@ class GRPCClient : public RPCClient {
221221
void AsyncSendEndPass(const std::string& ep,
222222
int64_t time_out = FLAGS_rpc_deadline) override;
223223

224-
void Wait() override;
224+
bool Wait() override;
225225

226226
void SendBeginPass() override;
227227

@@ -247,6 +247,7 @@ class GRPCClient : public RPCClient {
247247
std::mutex sync_mutex_;
248248
std::condition_variable sync_cond_;
249249
std::atomic<int64_t> req_count_{0};
250+
bool ok_;
250251

251252
// mutex for GetChannel thread safety
252253
std::mutex chan_mutex_;

paddle/fluid/operators/distributed/rpc_client.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class RPCClient {
7272
virtual void SendBeginPass() = 0;
7373
virtual void SendEndPass() = 0;
7474

75-
virtual void Wait() = 0;
75+
virtual bool Wait() = 0;
7676

7777
template <typename T>
7878
static RPCClient* GetInstance() {

paddle/fluid/operators/fetch_barrier_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase {
4545
distributed::RPCClient* rpc_client =
4646
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
4747

48-
rpc_client->Wait();
48+
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
4949

5050
for (auto& ep : eps) {
5151
VLOG(3) << "fetch barrier, ep: " << ep;
5252
rpc_client->AsyncSendFetchBarrier(ep);
5353
}
54-
rpc_client->Wait();
54+
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
5555
}
5656
};
5757

paddle/fluid/operators/prefetch_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase {
5353
VLOG(3) << "don't send no-initialied variable: " << ins[i];
5454
}
5555
}
56-
rpc_client->Wait();
56+
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
5757
}
5858
};
5959

paddle/fluid/operators/recv_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase {
5151
rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]);
5252
}
5353
if (sync_mode) {
54-
rpc_client->Wait();
54+
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
5555
}
5656
}
5757
};

paddle/fluid/operators/send_barrier_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@ class SendBarrierOp : public framework::OperatorBase {
5050
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
5151

5252
// need to wait before sending send_barrier message
53-
rpc_client->Wait();
53+
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
5454
if (sync_mode) {
5555
for (auto& ep : eps) {
5656
VLOG(3) << "send barrier, ep: " << ep;
5757
rpc_client->AsyncSendBatchBarrier(ep);
5858
}
59-
rpc_client->Wait();
59+
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
6060
}
6161
}
6262
};

0 commit comments

Comments
 (0)