Skip to content

Commit 6d93456

Browse files
authored
Merge pull request #10042 from jacquesqiao/add-async-listen-and-serv-op
listen_and_serv_op support async update
2 parents f457d5d + 3295f31 commit 6d93456

File tree

9 files changed

+220
-62
lines changed

9 files changed

+220
-62
lines changed

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH };
3030
class RequestBase {
3131
public:
3232
explicit RequestBase(GrpcService::AsyncService* service,
33-
::grpc::ServerCompletionQueue* cq,
33+
::grpc::ServerCompletionQueue* cq, bool sync_mode,
3434
const platform::DeviceContext* dev_ctx)
35-
: service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) {
35+
: service_(service),
36+
cq_(cq),
37+
sync_mode_(sync_mode),
38+
status_(PROCESS),
39+
dev_ctx_(dev_ctx) {
3640
PADDLE_ENFORCE(cq_);
3741
}
3842
virtual ~RequestBase() {}
@@ -49,18 +53,25 @@ class RequestBase {
4953
::grpc::ServerContext ctx_;
5054
GrpcService::AsyncService* service_;
5155
::grpc::ServerCompletionQueue* cq_;
56+
const bool sync_mode_;
5257
CallStatus status_;
5358
const platform::DeviceContext* dev_ctx_;
5459
};
5560

5661
class RequestSend final : public RequestBase {
5762
public:
5863
explicit RequestSend(GrpcService::AsyncService* service,
59-
::grpc::ServerCompletionQueue* cq,
64+
::grpc::ServerCompletionQueue* cq, bool sync_mode,
6065
framework::Scope* scope, ReceivedQueue* queue,
6166
const platform::DeviceContext* dev_ctx)
62-
: RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
63-
request_.reset(new VariableResponse(scope, dev_ctx_));
67+
: RequestBase(service, cq, sync_mode, dev_ctx),
68+
queue_(queue),
69+
responder_(&ctx_) {
70+
if (sync_mode_) {
71+
request_.reset(new VariableResponse(scope, dev_ctx_, false));
72+
} else {
73+
request_.reset(new VariableResponse(scope, dev_ctx_, true));
74+
}
6475
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
6576
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
6677
cq_, cq_, this);
@@ -87,11 +98,11 @@ class RequestSend final : public RequestBase {
8798
class RequestGet final : public RequestBase {
8899
public:
89100
explicit RequestGet(GrpcService::AsyncService* service,
90-
::grpc::ServerCompletionQueue* cq,
101+
::grpc::ServerCompletionQueue* cq, bool sync_mode,
91102
framework::Scope* scope,
92103
const platform::DeviceContext* dev_ctx,
93104
framework::BlockingQueue<MessageWithName>* queue)
94-
: RequestBase(service, cq, dev_ctx),
105+
: RequestBase(service, cq, sync_mode, dev_ctx),
95106
responder_(&ctx_),
96107
scope_(scope),
97108
queue_(queue) {
@@ -134,19 +145,23 @@ class RequestGet final : public RequestBase {
134145
class RequestPrefetch final : public RequestBase {
135146
public:
136147
explicit RequestPrefetch(GrpcService::AsyncService* service,
137-
::grpc::ServerCompletionQueue* cq,
148+
::grpc::ServerCompletionQueue* cq, bool sync_mode,
138149
framework::Scope* scope,
139150
const platform::DeviceContext* dev_ctx,
140151
framework::Executor* executor,
141152
framework::ProgramDesc* program,
142153
framework::ExecutorPrepareContext* prefetch_ctx)
143-
: RequestBase(service, cq, dev_ctx),
154+
: RequestBase(service, cq, sync_mode, dev_ctx),
144155
responder_(&ctx_),
145156
scope_(scope),
146157
executor_(executor),
147158
program_(program),
148159
prefetch_ctx_(prefetch_ctx) {
149-
request_.reset(new VariableResponse(scope, dev_ctx_));
160+
if (sync_mode_) {
161+
request_.reset(new VariableResponse(scope, dev_ctx_, false));
162+
} else {
163+
request_.reset(new VariableResponse(scope, dev_ctx_, true));
164+
}
150165
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
151166
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
152167
cq_, cq_, this);
@@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase {
181196
framework::Executor* executor_;
182197
framework::ProgramDesc* program_;
183198
framework::ExecutorPrepareContext* prefetch_ctx_;
184-
int blkid_;
185199
};
186200

187201
void AsyncGRPCServer::WaitClientGet(int count) {
@@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
254268
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
255269
return;
256270
}
257-
RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
258-
&var_recv_queue_, dev_ctx_);
271+
RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_,
272+
scope_, &var_recv_queue_, dev_ctx_);
259273
VLOG(4) << "Create RequestSend status:" << send->Status();
260274
}
261275

@@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
265279
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
266280
return;
267281
}
268-
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
269-
&var_get_queue_);
282+
RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
283+
dev_ctx_, &var_get_queue_);
270284
VLOG(4) << "Create RequestGet status:" << get->Status();
271285
}
272286

@@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
277291
return;
278292
}
279293
RequestPrefetch* prefetch =
280-
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
281-
executor_, program_, prefetch_ctx_);
294+
new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_,
295+
dev_ctx_, executor_, program_, prefetch_ctx_);
282296

283297
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
284298
}
@@ -301,9 +315,11 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
301315
VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
302316

303317
PADDLE_ENFORCE(tag);
304-
// FIXME(typhoonzero): de-couple the barriers with recv_op
305-
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
306-
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
318+
if (sync_mode_) {
319+
// FIXME(typhoonzero): de-couple the barriers with recv_op
320+
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
321+
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
322+
}
307323

308324
RequestBase* base = reinterpret_cast<RequestBase*>(tag);
309325
// reference:
@@ -320,13 +336,13 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
320336

321337
switch (base->Status()) {
322338
case PROCESS: {
323-
VLOG(4) << cq_name << " status:" << base->Status();
339+
VLOG(4) << cq_name << " PROCESS status:" << base->Status();
324340
TryToRegisterNewOne();
325341
base->Process();
326342
break;
327343
}
328344
case FINISH: {
329-
VLOG(4) << cq_name << " status:" << base->Status();
345+
VLOG(4) << cq_name << " FINISH status:" << base->Status();
330346
delete base;
331347
break;
332348
}

paddle/fluid/operators/detail/grpc_server.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ class RequestBase;
4444

4545
class AsyncGRPCServer final {
4646
public:
47-
explicit AsyncGRPCServer(const std::string &address) : address_(address) {}
47+
explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
48+
: address_(address), sync_mode_(sync_mode) {}
4849

4950
void RunSyncUpdate();
5051

@@ -95,6 +96,7 @@ class AsyncGRPCServer final {
9596
std::unique_ptr<::grpc::Server> server_;
9697

9798
std::string address_;
99+
const bool sync_mode_;
98100
framework::Scope *scope_;
99101
const platform::DeviceContext *dev_ctx_;
100102

paddle/fluid/operators/detail/grpc_server_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
8989
}
9090

9191
void StartServer(const std::string& endpoint) {
92-
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
92+
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, true));
9393
framework::ProgramDesc program;
9494
framework::Scope scope;
9595
platform::CPUPlace place;

paddle/fluid/operators/detail/variable_response.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ class VariableResponse {
4646
}
4747

4848
virtual ~VariableResponse() {
49-
if (create_scope_) scope_->DeleteScope(local_scope_);
49+
if (create_scope_) {
50+
scope_->DeleteScope(local_scope_);
51+
}
5052
}
5153

5254
// return:
@@ -63,6 +65,8 @@ class VariableResponse {
6365

6466
const framework::Scope& GetLocalScope() const { return *local_scope_; }
6567

68+
// Returns local_scope_ as a non-const pointer so the caller can modify the
// scope; the const qualifier applies to this object only, not to the scope.
framework::Scope* GetMutableLocalScope() const { return local_scope_; }
69+
6670
inline std::string Varname() { return meta_.varname(); }
6771
inline std::string OutVarname() { return meta_.out_varname(); }
6872

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,38 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
2727
VLOG(4) << "RunServer thread end";
2828
}
2929

30+
// Splits `str` on every occurrence of `sep` and stores the fields in `*pieces`
// (any previous contents of `*pieces` are discarded).
//
// Empty fields before or between separators are kept; an empty field after the
// last separator is dropped, so "a:" yields {"a"} and an empty string yields
// no pieces at all.
static void split(const std::string &str, char sep,
                  std::vector<std::string> *pieces) {
  pieces->clear();
  size_t pos = 0;
  for (size_t next = str.find(sep, pos); next != std::string::npos;
       next = str.find(sep, pos)) {
    pieces->push_back(str.substr(pos, next - pos));
    pos = next + 1;
  }
  // Only a non-empty tail counts as a field. Comparing offsets avoids building
  // the tail substring twice (once to test it, once to store it).
  if (pos < str.size()) {
    pieces->push_back(str.substr(pos));
  }
}
47+
48+
// Runs one prepared block on `scope` through framework::Async and blocks until
// that task has finished.
//
// A std::exception thrown while running the block is logged and swallowed
// inside the task, so it is not rethrown to the caller.
static void AsyncExecuteBlock(framework::Executor *executor,
                              framework::ExecutorPrepareContext *prepared,
                              framework::Scope *scope) {
  // Capture the pointers by value: the closure then stays valid even if the
  // wait below is ever removed and the task outlives this stack frame.
  std::future<void> future = framework::Async([executor, prepared, scope]() {
    try {
      executor->RunPreparedContext(prepared, scope, false, false);
    } catch (const std::exception &e) {
      LOG(ERROR) << "run sub program error " << e.what();
    }
  });
  // TODO(qiao) maybe we can remove this
  future.wait();
}
61+
3062
static void ParallelExecuteBlocks(
3163
const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
3264
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
@@ -169,15 +201,82 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
169201
} // while(true)
170202
}
171203

204+
// Serves gradients in async mode: each gradient received from the RPC service
// immediately runs the optimize block registered for it in the
// "grad_to_block_id" attribute. The loop ends, and the RPC service is shut
// down, when LISTEN_TERMINATE_MESSAGE is received.
//
// Every "grad_to_block_id" entry must have the form "<grad name>:<block id>".
// `recv_scope` and `prefetch_block` are not used by this loop.
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
                                   framework::ProgramDesc *program,
                                   framework::Scope *recv_scope,
                                   framework::BlockDesc *prefetch_block) const {
  VLOG(3) << "RunAsyncLoop in";
  // grad name to block id
  std::unordered_map<std::string, int32_t> grad_to_block_id;
  std::unordered_map<int32_t, std::string> id_to_grad;

  const auto &grad_to_block_id_str =
      Attr<std::vector<std::string>>("grad_to_block_id");
  for (const auto &grad_and_id : grad_to_block_id_str) {
    std::vector<std::string> pieces;
    split(grad_and_id, ':', &pieces);
    // Validate the entry before indexing into it; logging first would read
    // out of bounds on a malformed entry.
    PADDLE_ENFORCE_EQ(pieces.size(), 2);
    VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
    PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
    int block_id = std::stoi(pieces[1]);
    grad_to_block_id[pieces[0]] = block_id;
    id_to_grad[block_id] = pieces[0];
  }
  size_t num_blocks = program->Size();
  PADDLE_ENFORCE_GE(num_blocks, 2,
                    "server program should have at least 2 blocks");

  // Prepare every block except block 0.
  std::vector<int> block_list;
  for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
    block_list.push_back(static_cast<int>(blkid));
  }
  auto optimize_prepared = executor->Prepare(*program, block_list);
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>
      grad_to_prepared_ctx;
  for (size_t i = 0; i < block_list.size(); ++i) {
    // Blocks with no gradient mapped to them are skipped instead of being
    // registered under an empty gradient name.
    auto grad_it = id_to_grad.find(block_list[i]);
    if (grad_it == id_to_grad.end()) {
      continue;
    }
    grad_to_prepared_ctx[grad_it->second] = optimize_prepared[i];
  }

  VLOG(3) << "RunAsyncLoop into while";
  while (true) {
    const detail::ReceivedMessage v = rpc_service_->Get();
    const auto &recv_var_name = v.first;
    if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
      LOG(INFO) << "received terminate message and exit";
      break;
    }

    VLOG(3) << "received grad: " << recv_var_name;
    auto var = v.second->GetVar();
    if (var == nullptr) {
      LOG(ERROR) << "Can not find server side var: " << recv_var_name;
      PADDLE_THROW("Can not find server side var");
    }
    // Look the context up without inserting: an unknown gradient must fail
    // loudly rather than hand a null context to the executor.
    auto ctx_it = grad_to_prepared_ctx.find(recv_var_name);
    if (ctx_it == grad_to_prepared_ctx.end()) {
      LOG(ERROR) << "No optimize block registered for grad: " << recv_var_name;
      PADDLE_THROW("No optimize block registered for received grad");
    }
    AsyncExecuteBlock(executor, ctx_it->second.get(),
                      v.second->GetMutableLocalScope());
  }  // while(true)

  // Reached only after the terminate message. Shutting down here makes the
  // call reachable; previously it sat behind a `break` that already left the
  // loop.
  rpc_service_->ShutDown();
}
267+
172268
void ListenAndServOp::RunImpl(const framework::Scope &scope,
173269
const platform::Place &dev_place) const {
174270
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
175271
auto &dev_ctx = *pool.Get(dev_place);
176272
framework::Scope &recv_scope = scope.NewScope();
177273

274+
bool sync_mode = Attr<bool>("sync_mode");
275+
178276
PADDLE_ENFORCE(!rpc_service_);
179277
std::string endpoint = Attr<std::string>("endpoint");
180-
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
278+
279+
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode));
181280

182281
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
183282
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
@@ -202,7 +301,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
202301
sleep(5);
203302
// Write to a file of server selected port for python use.
204303
SavePort(rpc_service_);
205-
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
304+
if (sync_mode) {
305+
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
306+
} else {
307+
RunAsyncLoop(&executor, program, &recv_scope, prefetch_block);
308+
}
206309
}
207310

208311
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -221,6 +324,12 @@ from send_op and send back variables to recv_op.
221324
"IP address to listen on.")
222325
.SetDefault("127.0.0.1:6164")
223326
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
327+
AddAttr<std::vector<std::string>>(
328+
"grad_to_block_id",
329+
330+
"a map from grad name to it's optimize block id")
331+
.SetDefault({});
332+
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
224333
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
225334
"BlockID to run on server side.");
226335
AddAttr<framework::BlockDesc *>(kPrefetchBlock,

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ class ListenAndServOp : public framework::OperatorBase {
4646
framework::Scope* recv_scope,
4747
framework::BlockDesc* prefetch_block) const;
4848

49+
void RunAsyncLoop(framework::Executor* executor,
50+
framework::ProgramDesc* program,
51+
framework::Scope* recv_scope,
52+
framework::BlockDesc* prefetch_block) const;
53+
4954
void Stop() override;
5055

5156
void RunImpl(const framework::Scope& scope,

0 commit comments

Comments
 (0)