Skip to content

Commit 8cee9f6

Browse files
authored
Fix rpcclient's wait action in aync env. (#13307)
1 parent 7f692b8 commit 8cee9f6

File tree

9 files changed

+296
-175
lines changed

9 files changed

+296
-175
lines changed

paddle/fluid/operators/distributed/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ if(WITH_GRPC)
2020
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
2121
cc_test(rpc_server_test SRCS rpc_server_test.cc
2222
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
23+
cc_test(varhandle_test SRCS varhandle_test.cc)
2324
return()
2425
endif()
2526

paddle/fluid/operators/distributed/grpc_client.cc

Lines changed: 69 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -59,40 +59,32 @@ GRPCClient::~GRPCClient() {
5959
}
6060
channels_.clear();
6161
}
62-
6362
client_thread_->join();
6463
}
6564

66-
bool GRPCClient::AsyncSendVar(const std::string& ep,
67-
const platform::DeviceContext& ctx,
68-
const framework::Scope& scope,
69-
const std::string& var_name, int64_t time_out) {
65+
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
66+
const platform::DeviceContext& ctx,
67+
const framework::Scope& scope,
68+
const std::string& var_name,
69+
int64_t time_out) {
7070
const platform::DeviceContext* p_ctx = &ctx;
7171
const std::string ep_val = ep;
7272
const std::string var_name_val = var_name;
7373
const framework::Scope* p_scope = &scope;
7474
const auto ch = GetChannel(ep_val);
75+
SendProcessor* s = new SendProcessor(ch);
76+
VarHandlePtr h(new VarHandle(ep, "Send", var_name_val, p_ctx, p_scope));
77+
s->Prepare(h, time_out);
7578

76-
framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
77-
this] {
79+
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
7880
auto* var = p_scope->FindVar(var_name_val);
7981

8082
::grpc::ByteBuffer req;
8183
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req);
8284

83-
// varhandle
84-
VarHandle var_h;
85-
var_h.ep = ep_val;
86-
var_h.scope = p_scope;
87-
var_h.name = var_name_val;
88-
var_h.ctx = p_ctx;
89-
var_h.method = "Send";
90-
91-
VLOG(3) << var_h.String() << " begin";
85+
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
9286

9387
// stub context
94-
SendProcessor* s = new SendProcessor(ch);
95-
s->Prepare(var_h, time_out);
9688
s->response_call_back_ = nullptr;
9789

9890
auto call = s->stub_g_.PrepareUnaryCall(
@@ -102,13 +94,13 @@ bool GRPCClient::AsyncSendVar(const std::string& ep,
10294
});
10395
req_count_++;
10496

105-
return true;
97+
return h;
10698
}
10799

108100
void ProcGetResponse(const VarHandle& var_h,
109101
const ::grpc::ByteBuffer& ret_msg) {
110102
framework::Variable* outvar = nullptr;
111-
DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
103+
DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar);
112104
}
113105

114106
template <typename T>
@@ -119,37 +111,30 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
119111
result->Swap(&tmp);
120112
}
121113

122-
bool GRPCClient::AsyncGetVar(const std::string& ep,
123-
const platform::DeviceContext& ctx,
124-
const framework::Scope& scope,
125-
const std::string& var_name, int64_t time_out) {
114+
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
115+
const platform::DeviceContext& ctx,
116+
const framework::Scope& scope,
117+
const std::string& var_name,
118+
int64_t time_out) {
126119
const platform::DeviceContext* p_ctx = &ctx;
127120
const std::string ep_val = ep;
128121
const std::string var_name_val = var_name;
129122
const framework::Scope* p_scope = &scope;
130123
const auto ch = GetChannel(ep_val);
124+
GetProcessor* s = new GetProcessor(ch);
125+
VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope));
126+
s->Prepare(h, time_out);
131127

132-
framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
133-
this] {
128+
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
134129
// prepare input
135130
sendrecv::VariableMessage req;
136131
req.set_varname(var_name_val);
137132
::grpc::ByteBuffer buf;
138133
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
139134

140-
// var handle
141-
VarHandle var_h;
142-
var_h.ep = ep_val;
143-
var_h.scope = p_scope;
144-
var_h.name = var_name_val;
145-
var_h.ctx = p_ctx;
146-
var_h.method = "Get";
147-
148-
VLOG(3) << var_h.String() << " begin";
135+
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
149136

150137
// stub context
151-
GetProcessor* s = new GetProcessor(ch);
152-
s->Prepare(var_h, time_out);
153138
s->response_call_back_ = ProcGetResponse;
154139

155140
auto call = s->stub_g_.PrepareUnaryCall(
@@ -160,42 +145,36 @@ bool GRPCClient::AsyncGetVar(const std::string& ep,
160145

161146
req_count_++;
162147

163-
return true;
148+
return h;
164149
}
165150

166-
bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
167-
const platform::DeviceContext& ctx,
168-
const framework::Scope& scope,
169-
const std::string& in_var_name,
170-
const std::string& out_var_name,
171-
int64_t time_out) {
151+
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
152+
const platform::DeviceContext& ctx,
153+
const framework::Scope& scope,
154+
const std::string& in_var_name,
155+
const std::string& out_var_name,
156+
int64_t time_out) {
172157
const platform::DeviceContext* p_ctx = &ctx;
173158
const std::string ep_val = ep;
174159
const std::string in_var_name_val = in_var_name;
175160
const std::string out_var_name_val = out_var_name;
176161
const framework::Scope* p_scope = &scope;
177162
const auto ch = GetChannel(ep_val);
163+
GetProcessor* s = new GetProcessor(ch);
164+
VarHandlePtr h(
165+
new VarHandle(ep, "Prefetch", out_var_name_val, p_ctx, p_scope));
166+
s->Prepare(h, time_out);
178167

179168
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
180-
time_out, ch, this] {
169+
time_out, s, this] {
181170
auto* var = p_scope->FindVar(in_var_name_val);
182171

183172
::grpc::ByteBuffer req;
184173
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);
185174

186-
// var handle
187-
VarHandle var_h;
188-
var_h.ep = ep_val;
189-
var_h.scope = p_scope;
190-
var_h.name = out_var_name_val;
191-
var_h.ctx = p_ctx;
192-
var_h.method = "Prefetch";
193-
194-
VLOG(3) << var_h.String() << " begin";
175+
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
195176

196177
// stub context
197-
GetProcessor* s = new GetProcessor(ch);
198-
s->Prepare(var_h, time_out);
199178
s->response_call_back_ = ProcGetResponse;
200179

201180
auto call = s->stub_g_.PrepareUnaryCall(
@@ -206,56 +185,68 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
206185
});
207186

208187
req_count_++;
209-
return true;
188+
return h;
210189
}
211190

212-
void GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
213-
int64_t time_out) {
191+
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
192+
int64_t time_out) {
214193
const auto ch = GetChannel(ep);
215194

216195
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
217-
s->Prepare(time_out);
196+
VarHandlePtr h(new VarHandle(ep, "BatchBarrier", BATCH_BARRIER_MESSAGE,
197+
nullptr, nullptr));
198+
s->Prepare(h, time_out);
218199

219200
sendrecv::VariableMessage req;
220201
req.set_varname(BATCH_BARRIER_MESSAGE);
221202
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
222203
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
223204
req_count_++;
205+
return h;
224206
}
225207

226-
void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
227-
int64_t time_out) {
208+
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
209+
int64_t time_out) {
228210
const auto ch = GetChannel(ep);
229211
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
230-
s->Prepare(time_out);
212+
VarHandlePtr h(new VarHandle(ep, "FetchBarrier", FETCH_BARRIER_MESSAGE,
213+
nullptr, nullptr));
214+
s->Prepare(h, time_out);
231215

232216
sendrecv::VariableMessage req;
233217
req.set_varname(FETCH_BARRIER_MESSAGE);
234218
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
235219
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
236220
req_count_++;
221+
return h;
237222
}
238223

239-
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
224+
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
225+
int64_t time_out) {
240226
const auto ch = GetChannel(ep);
241227

242228
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
243-
s->Prepare(time_out);
229+
VarHandlePtr h(
230+
new VarHandle(ep, "SendComplete", COMPLETE_MESSAGE, nullptr, nullptr));
231+
s->Prepare(h, time_out);
244232

245233
sendrecv::VariableMessage req;
246234
req.set_varname(COMPLETE_MESSAGE);
247235
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
248236
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
249237
req_count_++;
238+
return h;
250239
}
251240

252-
void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
253-
const std::string& dir,
254-
int64_t time_out) {
241+
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
242+
const std::string& dir,
243+
int64_t time_out) {
255244
const auto ch = GetChannel(ep);
256245

257246
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
258-
s->Prepare(time_out);
247+
VarHandlePtr h(new VarHandle(ep, "CheckPointNotify", CHECKPOINT_SAVE_MESSAGE,
248+
nullptr, nullptr));
249+
s->Prepare(h, time_out);
259250

260251
sendrecv::VariableMessage req;
261252
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
@@ -264,6 +255,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
264255
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
265256
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
266257
req_count_++;
258+
return h;
267259
}
268260

269261
bool GRPCClient::Wait() {
@@ -276,32 +268,36 @@ void GRPCClient::Proceed() {
276268
void* tag = nullptr;
277269
bool ok = false;
278270

271+
VLOG(3) << "GRPCClient Proceed begin";
279272
while (!stopped_ && cq_.Next(&tag, &ok)) {
280273
BaseProcessor* c = static_cast<BaseProcessor*>(tag);
281274
GPR_ASSERT(ok);
282275
PADDLE_ENFORCE(c);
283276
if (c->status_.ok()) {
284-
VLOG(3) << c->var_h_.String() << " process";
277+
VLOG(3) << c->GetVarHandlePtr()->String() << " process";
285278
c->Process();
286279
} else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
287-
LOG(ERROR) << c->var_h_.String()
280+
LOG(ERROR) << c->GetVarHandlePtr()->String()
288281
<< " meets grpc error:" << c->status_.error_message();
289282
{
290283
std::lock_guard<std::mutex> lk(sync_mutex_);
291284
ok_ = false;
292285
}
293-
sync_cond_.notify_all();
286+
c->Finish(false);
294287
} else {
295-
LOG(FATAL) << c->var_h_.String()
288+
LOG(FATAL) << c->GetVarHandlePtr()->String()
296289
<< " meets grpc error:" << c->status_.error_message();
290+
c->Finish(false);
297291
}
292+
298293
delete c;
299294
{
300295
std::lock_guard<std::mutex> lk(sync_mutex_);
301296
req_count_--;
302297
}
303298
sync_cond_.notify_all();
304299
}
300+
VLOG(3) << "GRPCClient Proceed end";
305301
}
306302

307303
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {

0 commit comments

Comments
 (0)