@@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH };
30
30
class RequestBase {
31
31
public:
32
32
explicit RequestBase (GrpcService::AsyncService* service,
33
- ::grpc::ServerCompletionQueue* cq,
33
+ ::grpc::ServerCompletionQueue* cq, bool sync_mode,
34
34
const platform::DeviceContext* dev_ctx)
35
- : service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) {
35
+ : service_(service),
36
+ cq_(cq),
37
+ sync_mode_(sync_mode),
38
+ status_(PROCESS),
39
+ dev_ctx_(dev_ctx) {
36
40
PADDLE_ENFORCE (cq_);
37
41
}
38
42
virtual ~RequestBase () {}
@@ -49,18 +53,25 @@ class RequestBase {
49
53
::grpc::ServerContext ctx_;
50
54
GrpcService::AsyncService* service_;
51
55
::grpc::ServerCompletionQueue* cq_;
56
+ const bool sync_mode_;
52
57
CallStatus status_;
53
58
const platform::DeviceContext* dev_ctx_;
54
59
};
55
60
56
61
class RequestSend final : public RequestBase {
57
62
public:
58
63
explicit RequestSend (GrpcService::AsyncService* service,
59
- ::grpc::ServerCompletionQueue* cq,
64
+ ::grpc::ServerCompletionQueue* cq, bool sync_mode,
60
65
framework::Scope* scope, ReceivedQueue* queue,
61
66
const platform::DeviceContext* dev_ctx)
62
- : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
63
- request_.reset (new VariableResponse (scope, dev_ctx_));
67
+ : RequestBase(service, cq, sync_mode, dev_ctx),
68
+ queue_(queue),
69
+ responder_(&ctx_) {
70
+ if (sync_mode_) {
71
+ request_.reset (new VariableResponse (scope, dev_ctx_, false ));
72
+ } else {
73
+ request_.reset (new VariableResponse (scope, dev_ctx_, true ));
74
+ }
64
75
int method_id = static_cast <int >(detail::GrpcMethod::kSendVariable );
65
76
service_->RequestAsyncUnary (method_id, &ctx_, request_.get (), &responder_,
66
77
cq_, cq_, this );
@@ -87,11 +98,11 @@ class RequestSend final : public RequestBase {
87
98
class RequestGet final : public RequestBase {
88
99
public:
89
100
explicit RequestGet (GrpcService::AsyncService* service,
90
- ::grpc::ServerCompletionQueue* cq,
101
+ ::grpc::ServerCompletionQueue* cq, bool sync_mode,
91
102
framework::Scope* scope,
92
103
const platform::DeviceContext* dev_ctx,
93
104
framework::BlockingQueue<MessageWithName>* queue)
94
- : RequestBase(service, cq, dev_ctx),
105
+ : RequestBase(service, cq, sync_mode, dev_ctx),
95
106
responder_(&ctx_),
96
107
scope_(scope),
97
108
queue_(queue) {
@@ -134,19 +145,23 @@ class RequestGet final : public RequestBase {
134
145
class RequestPrefetch final : public RequestBase {
135
146
public:
136
147
explicit RequestPrefetch (GrpcService::AsyncService* service,
137
- ::grpc::ServerCompletionQueue* cq,
148
+ ::grpc::ServerCompletionQueue* cq, bool sync_mode,
138
149
framework::Scope* scope,
139
150
const platform::DeviceContext* dev_ctx,
140
151
framework::Executor* executor,
141
152
framework::ProgramDesc* program,
142
153
framework::ExecutorPrepareContext* prefetch_ctx)
143
- : RequestBase(service, cq, dev_ctx),
154
+ : RequestBase(service, cq, sync_mode, dev_ctx),
144
155
responder_(&ctx_),
145
156
scope_(scope),
146
157
executor_(executor),
147
158
program_(program),
148
159
prefetch_ctx_(prefetch_ctx) {
149
- request_.reset (new VariableResponse (scope, dev_ctx_));
160
+ if (sync_mode_) {
161
+ request_.reset (new VariableResponse (scope, dev_ctx_, false ));
162
+ } else {
163
+ request_.reset (new VariableResponse (scope, dev_ctx_, true ));
164
+ }
150
165
int method_id = static_cast <int >(detail::GrpcMethod::kPrefetchVariable );
151
166
service_->RequestAsyncUnary (method_id, &ctx_, request_.get (), &responder_,
152
167
cq_, cq_, this );
@@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase {
181
196
framework::Executor* executor_;
182
197
framework::ProgramDesc* program_;
183
198
framework::ExecutorPrepareContext* prefetch_ctx_;
184
- int blkid_;
185
199
};
186
200
187
201
void AsyncGRPCServer::WaitClientGet (int count) {
@@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
254
268
VLOG (3 ) << " shutdown, do not TryToRegisterNewSendOne" ;
255
269
return ;
256
270
}
257
- RequestSend* send = new RequestSend (&service_, cq_send_.get (), scope_ ,
258
- &var_recv_queue_, dev_ctx_);
271
+ RequestSend* send = new RequestSend (&service_, cq_send_.get (), sync_mode_ ,
272
+ scope_, &var_recv_queue_, dev_ctx_);
259
273
VLOG (4 ) << " Create RequestSend status:" << send->Status ();
260
274
}
261
275
@@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
265
279
VLOG (3 ) << " shutdown, do not TryToRegisterNewGetOne" ;
266
280
return ;
267
281
}
268
- RequestGet* get = new RequestGet (&service_, cq_get_.get (), scope_, dev_ctx_ ,
269
- &var_get_queue_);
282
+ RequestGet* get = new RequestGet (&service_, cq_get_.get (), sync_mode_, scope_ ,
283
+ dev_ctx_, &var_get_queue_);
270
284
VLOG (4 ) << " Create RequestGet status:" << get->Status ();
271
285
}
272
286
@@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
277
291
return ;
278
292
}
279
293
RequestPrefetch* prefetch =
280
- new RequestPrefetch (&service_, cq_prefetch_.get (), scope_, dev_ctx_ ,
281
- executor_, program_, prefetch_ctx_);
294
+ new RequestPrefetch (&service_, cq_prefetch_.get (), sync_mode_, scope_ ,
295
+ dev_ctx_, executor_, program_, prefetch_ctx_);
282
296
283
297
VLOG (4 ) << " Create RequestPrefetch status:" << prefetch->Status ();
284
298
}
@@ -301,9 +315,11 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
301
315
VLOG (3 ) << " HandleRequest for " << cq_name << " while after Next" ;
302
316
303
317
PADDLE_ENFORCE (tag);
304
- // FIXME(typhoonzero): de-couple the barriers with recv_op
305
- if (!is_shut_down_ && cq_name == " cq_get" ) WaitCond (1 );
306
- if (!is_shut_down_ && cq_name == " cq_send" ) WaitCond (0 );
318
+ if (sync_mode_) {
319
+ // FIXME(typhoonzero): de-couple the barriers with recv_op
320
+ if (!is_shut_down_ && cq_name == " cq_get" ) WaitCond (1 );
321
+ if (!is_shut_down_ && cq_name == " cq_send" ) WaitCond (0 );
322
+ }
307
323
308
324
RequestBase* base = reinterpret_cast <RequestBase*>(tag);
309
325
// reference:
@@ -320,13 +336,13 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
320
336
321
337
switch (base->Status ()) {
322
338
case PROCESS: {
323
- VLOG (4 ) << cq_name << " status:" << base->Status ();
339
+ VLOG (4 ) << cq_name << " PROCESS status:" << base->Status ();
324
340
TryToRegisterNewOne ();
325
341
base->Process ();
326
342
break ;
327
343
}
328
344
case FINISH: {
329
- VLOG (4 ) << cq_name << " status:" << base->Status ();
345
+ VLOG (4 ) << cq_name << " FINISH status:" << base->Status ();
330
346
delete base;
331
347
break ;
332
348
}
0 commit comments