Skip to content

Commit faebadd

Browse files
authored
Merge pull request #10228 from jacquesqiao/use-multi-thread-todo-update
Use multi thread to do update
2 parents ff99d94 + d86626d commit faebadd

File tree

4 files changed

+61
-36
lines changed

4 files changed

+61
-36
lines changed

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ class RequestSend final : public RequestBase {
8282
virtual std::string GetReqName() { return request_->Varname(); }
8383

8484
virtual void Process() {
85-
queue_->Push(std::make_pair(request_->Varname(), request_));
85+
std::string var_name = GetReqName();
86+
VLOG(3) << "RequestSend " << var_name;
87+
queue_->Push(std::make_pair(var_name, request_));
8688

8789
sendrecv::VoidMessage reply;
8890
responder_.Finish(reply, ::grpc::Status::OK, this);
@@ -106,7 +108,7 @@ class RequestGet final : public RequestBase {
106108
responder_(&ctx_),
107109
scope_(scope),
108110
queue_(queue) {
109-
int method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
111+
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
110112
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
111113
cq_, this);
112114
}
@@ -118,6 +120,7 @@ class RequestGet final : public RequestBase {
118120
virtual void Process() {
119121
// proc request.
120122
std::string var_name = request_.varname();
123+
VLOG(3) << "RequestGet " << var_name;
121124
auto* var = scope_->FindVar(var_name);
122125

123126
::grpc::ByteBuffer reply;
@@ -176,7 +179,7 @@ class RequestPrefetch final : public RequestBase {
176179
::grpc::ByteBuffer reply;
177180

178181
std::string var_name = request_->OutVarname();
179-
VLOG(3) << "prefetch var " << var_name;
182+
VLOG(3) << "RequestPrefetch " << var_name;
180183
auto var_desc = program_->Block(0).FindVar(var_name);
181184
framework::Scope* local_scope = &scope_->NewScope();
182185
auto* var = local_scope->FindVar(var_name);
@@ -307,18 +310,20 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
307310
bool ok = false;
308311

309312
while (true) {
310-
VLOG(3) << "HandleRequest for " << cq_name << " while in";
313+
VLOG(3) << "HandleRequest for " << cq_name << " wait Next";
311314
if (!cq->Next(&tag, &ok)) {
312315
LOG(INFO) << cq_name << " CompletionQueue shutdown!";
313316
break;
314317
}
315-
VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
318+
VLOG(3) << "HandleRequest for " << cq_name << " get Next";
316319

317320
PADDLE_ENFORCE(tag);
321+
318322
if (sync_mode_) {
319323
// FIXME(typhoonzero): de-couple the barriers with recv_op
320324
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
321325
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
326+
VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond";
322327
}
323328

324329
RequestBase* base = reinterpret_cast<RequestBase*>(tag);
@@ -336,9 +341,9 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
336341

337342
switch (base->Status()) {
338343
case PROCESS: {
339-
VLOG(4) << cq_name << " PROCESS status:" << base->Status();
340344
TryToRegisterNewOne();
341345
base->Process();
346+
VLOG(4) << cq_name << " PROCESS status:" << base->Status();
342347
break;
343348
}
344349
case FINISH: {

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,6 @@ static void split(const std::string &str, char sep,
4545
}
4646
}
4747

48-
static void AsyncExecuteBlock(framework::Executor *executor,
49-
framework::ExecutorPrepareContext *prepared,
50-
framework::Scope *scope) {
51-
std::future<void> future = framework::Async([&executor, &prepared, &scope]() {
52-
try {
53-
executor->RunPreparedContext(prepared, scope, false, false);
54-
} catch (std::exception &e) {
55-
LOG(ERROR) << "run sub program error " << e.what();
56-
}
57-
});
58-
// TODO(qiao) maybe we can remove this
59-
future.wait();
60-
}
61-
6248
static void ParallelExecuteBlocks(
6349
const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
6450
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
@@ -201,14 +187,40 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
201187
} // while(true)
202188
}
203189

190+
static void AsyncUpdateThread(
191+
const std::string &var_name, const bool &exit_flag,
192+
const std::shared_ptr<detail::ReceivedQueue> &queue,
193+
framework::Executor *executor,
194+
framework::ExecutorPrepareContext *prepared) {
195+
VLOG(3) << "update thread for " << var_name << " started";
196+
while (!exit_flag) {
197+
const detail::ReceivedMessage v = queue->Pop();
198+
auto recv_var_name = v.first;
199+
auto var = v.second->GetVar();
200+
if (var == nullptr) {
201+
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
202+
PADDLE_THROW("Can not find server side var");
203+
}
204+
auto fs = framework::Async([var_name, &executor, &v, prepared] {
205+
try {
206+
executor->RunPreparedContext(prepared, v.second->GetMutableLocalScope(),
207+
false, false);
208+
} catch (std::exception &e) {
209+
LOG(ERROR) << "run sub program error " << e.what();
210+
}
211+
});
212+
fs.wait();
213+
}
214+
}
215+
204216
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
205-
framework::ProgramDesc *program,
206-
framework::Scope *recv_scope,
207-
framework::BlockDesc *prefetch_block) const {
217+
framework::ProgramDesc *program) const {
208218
VLOG(3) << "RunAsyncLoop in";
209219
// grad name to block id
210220
std::unordered_map<std::string, int32_t> grad_to_block_id;
211221
std::unordered_map<int32_t, std::string> id_to_grad;
222+
std::unordered_map<std::string, std::shared_ptr<detail::ReceivedQueue>>
223+
grad_to_queue;
212224

213225
auto grad_to_block_id_str =
214226
Attr<std::vector<std::string>>("grad_to_block_id");
@@ -220,6 +232,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
220232
PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
221233
int block_id = std::stoi(pieces[1]);
222234
grad_to_block_id[pieces[0]] = block_id;
235+
grad_to_queue[pieces[0]] = std::make_shared<detail::ReceivedQueue>();
223236
id_to_grad[block_id] = pieces[0];
224237
}
225238
size_t num_blocks = program->Size();
@@ -238,8 +251,21 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
238251
grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
239252
}
240253

241-
VLOG(3) << "RunAsyncLoop into while";
242254
bool exit_flag = false;
255+
256+
VLOG(3) << "start async optimize threads";
257+
std::vector<std::future<void>> fs;
258+
for (auto iter = grad_to_queue.begin(); iter != grad_to_queue.end(); iter++) {
259+
std::string grad_name = iter->first;
260+
VLOG(3) << "create async update thread for " << grad_name;
261+
fs.push_back(framework::AsyncIO([grad_name, &exit_flag, &executor,
262+
&grad_to_queue, &grad_to_prepared_ctx]() {
263+
AsyncUpdateThread(grad_name, exit_flag, grad_to_queue[grad_name],
264+
executor, grad_to_prepared_ctx[grad_name].get());
265+
}));
266+
}
267+
268+
VLOG(3) << "RunAsyncLoop into while";
243269
while (!exit_flag) {
244270
const detail::ReceivedMessage v = rpc_service_->Get();
245271
auto recv_var_name = v.first;
@@ -249,13 +275,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
249275
break;
250276
} else {
251277
VLOG(3) << "received grad: " << recv_var_name;
252-
auto var = v.second->GetVar();
253-
if (var == nullptr) {
254-
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
255-
PADDLE_THROW("Can not find server side var");
256-
}
257-
AsyncExecuteBlock(executor, grad_to_prepared_ctx[recv_var_name].get(),
258-
v.second->GetMutableLocalScope());
278+
grad_to_queue[recv_var_name]->Push(v);
259279
}
260280

261281
if (exit_flag) {
@@ -304,7 +324,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
304324
if (sync_mode) {
305325
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
306326
} else {
307-
RunAsyncLoop(&executor, program, &recv_scope, prefetch_block);
327+
RunAsyncLoop(&executor, program);
308328
}
309329
}
310330

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,7 @@ class ListenAndServOp : public framework::OperatorBase {
4747
framework::BlockDesc* prefetch_block) const;
4848

4949
void RunAsyncLoop(framework::Executor* executor,
50-
framework::ProgramDesc* program,
51-
framework::Scope* recv_scope,
52-
framework::BlockDesc* prefetch_block) const;
50+
framework::ProgramDesc* program) const;
5351

5452
void Stop() override;
5553

python/paddle/fluid/layers/io.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,9 @@ def complete_op(self):
168168
'endpoint': self.endpoint,
169169
'Fanin': self.fan_in,
170170
'OptimizeBlock': current_block,
171-
'PrefetchBlock': empty_block
171+
'PrefetchBlock': empty_block,
172+
'sync_mode': True, # did not support async now in layers
173+
'grad_to_block_id': [""]
172174
})
173175

174176

0 commit comments

Comments
 (0)