@@ -59,40 +59,32 @@ GRPCClient::~GRPCClient() {
59
59
}
60
60
channels_.clear ();
61
61
}
62
-
63
62
client_thread_->join ();
64
63
}
65
64
66
- bool GRPCClient::AsyncSendVar (const std::string& ep,
67
- const platform::DeviceContext& ctx,
68
- const framework::Scope& scope,
69
- const std::string& var_name, int64_t time_out) {
65
+ VarHandlePtr GRPCClient::AsyncSendVar (const std::string& ep,
66
+ const platform::DeviceContext& ctx,
67
+ const framework::Scope& scope,
68
+ const std::string& var_name,
69
+ int64_t time_out) {
70
70
const platform::DeviceContext* p_ctx = &ctx;
71
71
const std::string ep_val = ep;
72
72
const std::string var_name_val = var_name;
73
73
const framework::Scope* p_scope = &scope;
74
74
const auto ch = GetChannel (ep_val);
75
+ SendProcessor* s = new SendProcessor (ch);
76
+ VarHandlePtr h (new VarHandle (ep, " Send" , var_name_val, p_ctx, p_scope));
77
+ s->Prepare (h, time_out);
75
78
76
- framework::AsyncIO ([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
77
- this ] {
79
+ framework::AsyncIO ([var_name_val, p_scope, p_ctx, s, this ] {
78
80
auto * var = p_scope->FindVar (var_name_val);
79
81
80
82
::grpc::ByteBuffer req;
81
83
SerializeToByteBuffer (var_name_val, var, *p_ctx, &req);
82
84
83
- // varhandle
84
- VarHandle var_h;
85
- var_h.ep = ep_val;
86
- var_h.scope = p_scope;
87
- var_h.name = var_name_val;
88
- var_h.ctx = p_ctx;
89
- var_h.method = " Send" ;
90
-
91
- VLOG (3 ) << var_h.String () << " begin" ;
85
+ VLOG (3 ) << s->GetVarHandlePtr ()->String () << " begin" ;
92
86
93
87
// stub context
94
- SendProcessor* s = new SendProcessor (ch);
95
- s->Prepare (var_h, time_out);
96
88
s->response_call_back_ = nullptr ;
97
89
98
90
auto call = s->stub_g_ .PrepareUnaryCall (
@@ -102,13 +94,13 @@ bool GRPCClient::AsyncSendVar(const std::string& ep,
102
94
});
103
95
req_count_++;
104
96
105
- return true ;
97
+ return h ;
106
98
}
107
99
108
100
// ProcGetResponse: response handler that deserializes a reply buffer.
//   var_h   - handle describing the request; supplies the device context and
//             the scope passed to the deserializer.
//   ret_msg - raw gRPC byte buffer received from the remote end.
// Elsewhere in this diff it is assigned to `response_call_back_` by the Get
// and Prefetch request paths.
void ProcGetResponse (const VarHandle& var_h,
109
101
const ::grpc::ByteBuffer& ret_msg) {
110
102
// `outvar` is only handed to the deserializer as an out-parameter; this
// function does not read it afterwards.
framework::Variable* outvar = nullptr ;
111
// Removed line: read `ctx` / `scope` as public data members of VarHandle.
- DeserializeFromByteBuffer (ret_msg, *var_h.ctx , var_h.scope , &outvar);
103
// Added line: the same call, going through the `ctx()` / `scope()` accessors.
+ DeserializeFromByteBuffer (ret_msg, *var_h.ctx () , var_h.scope () , &outvar);
112
104
}
113
105
114
106
template <typename T>
@@ -119,37 +111,30 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
119
111
result->Swap (&tmp);
120
112
}
121
113
122
- bool GRPCClient::AsyncGetVar (const std::string& ep,
123
- const platform::DeviceContext& ctx,
124
- const framework::Scope& scope,
125
- const std::string& var_name, int64_t time_out) {
114
+ VarHandlePtr GRPCClient::AsyncGetVar (const std::string& ep,
115
+ const platform::DeviceContext& ctx,
116
+ const framework::Scope& scope,
117
+ const std::string& var_name,
118
+ int64_t time_out) {
126
119
const platform::DeviceContext* p_ctx = &ctx;
127
120
const std::string ep_val = ep;
128
121
const std::string var_name_val = var_name;
129
122
const framework::Scope* p_scope = &scope;
130
123
const auto ch = GetChannel (ep_val);
124
+ GetProcessor* s = new GetProcessor (ch);
125
+ VarHandlePtr h (new VarHandle (ep, " Get" , var_name_val, p_ctx, p_scope));
126
+ s->Prepare (h, time_out);
131
127
132
- framework::AsyncIO ([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
133
- this ] {
128
+ framework::AsyncIO ([var_name_val, p_scope, p_ctx, s, this ] {
134
129
// prepare input
135
130
sendrecv::VariableMessage req;
136
131
req.set_varname (var_name_val);
137
132
::grpc::ByteBuffer buf;
138
133
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
139
134
140
- // var handle
141
- VarHandle var_h;
142
- var_h.ep = ep_val;
143
- var_h.scope = p_scope;
144
- var_h.name = var_name_val;
145
- var_h.ctx = p_ctx;
146
- var_h.method = " Get" ;
147
-
148
- VLOG (3 ) << var_h.String () << " begin" ;
135
+ VLOG (3 ) << s->GetVarHandlePtr ()->String () << " begin" ;
149
136
150
137
// stub context
151
- GetProcessor* s = new GetProcessor (ch);
152
- s->Prepare (var_h, time_out);
153
138
s->response_call_back_ = ProcGetResponse;
154
139
155
140
auto call = s->stub_g_ .PrepareUnaryCall (
@@ -160,42 +145,36 @@ bool GRPCClient::AsyncGetVar(const std::string& ep,
160
145
161
146
req_count_++;
162
147
163
- return true ;
148
+ return h ;
164
149
}
165
150
166
- bool GRPCClient::AsyncPrefetchVar (const std::string& ep,
167
- const platform::DeviceContext& ctx,
168
- const framework::Scope& scope,
169
- const std::string& in_var_name,
170
- const std::string& out_var_name,
171
- int64_t time_out) {
151
+ VarHandlePtr GRPCClient::AsyncPrefetchVar (const std::string& ep,
152
+ const platform::DeviceContext& ctx,
153
+ const framework::Scope& scope,
154
+ const std::string& in_var_name,
155
+ const std::string& out_var_name,
156
+ int64_t time_out) {
172
157
const platform::DeviceContext* p_ctx = &ctx;
173
158
const std::string ep_val = ep;
174
159
const std::string in_var_name_val = in_var_name;
175
160
const std::string out_var_name_val = out_var_name;
176
161
const framework::Scope* p_scope = &scope;
177
162
const auto ch = GetChannel (ep_val);
163
+ GetProcessor* s = new GetProcessor (ch);
164
+ VarHandlePtr h (
165
+ new VarHandle (ep, " Prefetch" , out_var_name_val, p_ctx, p_scope));
166
+ s->Prepare (h, time_out);
178
167
179
168
framework::AsyncIO ([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
180
- time_out, ch , this ] {
169
+ time_out, s , this ] {
181
170
auto * var = p_scope->FindVar (in_var_name_val);
182
171
183
172
::grpc::ByteBuffer req;
184
173
SerializeToByteBuffer (in_var_name_val, var, *p_ctx, &req, out_var_name_val);
185
174
186
- // var handle
187
- VarHandle var_h;
188
- var_h.ep = ep_val;
189
- var_h.scope = p_scope;
190
- var_h.name = out_var_name_val;
191
- var_h.ctx = p_ctx;
192
- var_h.method = " Prefetch" ;
193
-
194
- VLOG (3 ) << var_h.String () << " begin" ;
175
+ VLOG (3 ) << s->GetVarHandlePtr ()->String () << " begin" ;
195
176
196
177
// stub context
197
- GetProcessor* s = new GetProcessor (ch);
198
- s->Prepare (var_h, time_out);
199
178
s->response_call_back_ = ProcGetResponse;
200
179
201
180
auto call = s->stub_g_ .PrepareUnaryCall (
@@ -206,56 +185,68 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
206
185
});
207
186
208
187
req_count_++;
209
- return true ;
188
+ return h ;
210
189
}
211
190
212
// AsyncSendBatchBarrier: issues an asynchronous SendVariable RPC whose
// varname is BATCH_BARRIER_MESSAGE to endpoint `ep`.
//   ep       - endpoint string used to look up the channel.
//   time_out - forwarded to the processor's Prepare(); units are not visible
//              here (TODO confirm).
// This change switches the return type from void to VarHandlePtr: the handle
// created below is now returned to the caller.
- void GRPCClient::AsyncSendBatchBarrier (const std::string& ep,
213
- int64_t time_out) {
191
+ VarHandlePtr GRPCClient::AsyncSendBatchBarrier (const std::string& ep,
192
+ int64_t time_out) {
214
193
const auto ch = GetChannel (ep);
215
194
216
195
// Raw `new`: the pointer is passed as the completion tag in Finish() below.
// The completion loop shown later in this diff casts the tag back to
// BaseProcessor* and deletes it.
BatchBarrierProcessor* s = new BatchBarrierProcessor (ch);
217
- s->Prepare (time_out);
196
// The handle carries the endpoint, a method label and the barrier message as
// its name; device context and scope are both null for this request.
// NOTE(review): the leading space inside the method-label literal looks like
// an artifact of how this diff was rendered - confirm against the real file.
+ VarHandlePtr h (new VarHandle (ep, " BatchBarrier" , BATCH_BARRIER_MESSAGE,
197
+ nullptr , nullptr ));
198
+ s->Prepare (h, time_out);
218
199
219
200
sendrecv::VariableMessage req;
220
201
req.set_varname (BATCH_BARRIER_MESSAGE);
221
202
auto rpc = s->stub_ ->AsyncSendVariable (s->context_ .get (), req, &cq_);
222
203
// Reply and status are written into the processor; `s` itself is the tag.
rpc->Finish (&s->reply_ , &s->status_ , reinterpret_cast <void *>(s));
223
204
// Counts the outstanding request; the completion loop decrements it.
req_count_++;
205
+ return h;
224
206
}
225
207
226
// AsyncSendFetchBarrier: issues an asynchronous GetVariable RPC whose
// varname is FETCH_BARRIER_MESSAGE to endpoint `ep`.
//   ep       - endpoint string used to look up the channel.
//   time_out - forwarded to the processor's Prepare(); units are not visible
//              here (TODO confirm).
// This change switches the return type from void to VarHandlePtr: the handle
// created below is now returned to the caller.
- void GRPCClient::AsyncSendFetchBarrier (const std::string& ep,
227
- int64_t time_out) {
208
+ VarHandlePtr GRPCClient::AsyncSendFetchBarrier (const std::string& ep,
209
+ int64_t time_out) {
228
210
const auto ch = GetChannel (ep);
229
211
// Raw `new`: the pointer is passed as the completion tag in Finish() below.
// The completion loop shown later in this diff casts the tag back to
// BaseProcessor* and deletes it.
FetchBarrierProcessor* s = new FetchBarrierProcessor (ch);
230
- s->Prepare (time_out);
212
// The handle carries the endpoint, a method label and the barrier message as
// its name; device context and scope are both null for this request.
// NOTE(review): the leading space inside the method-label literal looks like
// an artifact of how this diff was rendered - confirm against the real file.
+ VarHandlePtr h (new VarHandle (ep, " FetchBarrier" , FETCH_BARRIER_MESSAGE,
213
+ nullptr , nullptr ));
214
+ s->Prepare (h, time_out);
231
215
232
216
sendrecv::VariableMessage req;
233
217
req.set_varname (FETCH_BARRIER_MESSAGE);
234
218
// Unlike the batch barrier, this goes out through AsyncGetVariable.
auto rpc = s->stub_ ->AsyncGetVariable (s->context_ .get (), req, &cq_);
235
219
// Reply and status are written into the processor; `s` itself is the tag.
rpc->Finish (&s->reply_ , &s->status_ , reinterpret_cast <void *>(s));
236
220
// Counts the outstanding request; the completion loop decrements it.
req_count_++;
221
+ return h;
237
222
}
238
223
239
// AsyncSendComplete: issues an asynchronous SendVariable RPC whose varname
// is COMPLETE_MESSAGE to endpoint `ep`.
//   ep       - endpoint string used to look up the channel.
//   time_out - forwarded to the processor's Prepare(); units are not visible
//              here (TODO confirm).
// This change switches the return type from void to VarHandlePtr: the handle
// created below is now returned to the caller.
- void GRPCClient::AsyncSendComplete (const std::string& ep, int64_t time_out) {
224
+ VarHandlePtr GRPCClient::AsyncSendComplete (const std::string& ep,
225
+ int64_t time_out) {
240
226
const auto ch = GetChannel (ep);
241
227
242
228
// Uses the same processor type as the batch-barrier request. The raw pointer
// is passed as the completion tag in Finish() below; the completion loop
// shown later in this diff casts the tag back to BaseProcessor* and deletes it.
BatchBarrierProcessor* s = new BatchBarrierProcessor (ch);
243
- s->Prepare (time_out);
229
// The handle carries the endpoint, a method label and the complete message as
// its name; device context and scope are both null for this request.
// NOTE(review): the leading space inside the method-label literal looks like
// an artifact of how this diff was rendered - confirm against the real file.
+ VarHandlePtr h (
230
+ new VarHandle (ep, " SendComplete" , COMPLETE_MESSAGE, nullptr , nullptr ));
231
+ s->Prepare (h, time_out);
244
232
245
233
sendrecv::VariableMessage req;
246
234
req.set_varname (COMPLETE_MESSAGE);
247
235
auto rpc = s->stub_ ->AsyncSendVariable (s->context_ .get (), req, &cq_);
248
236
// Reply and status are written into the processor; `s` itself is the tag.
rpc->Finish (&s->reply_ , &s->status_ , reinterpret_cast <void *>(s));
249
237
// Counts the outstanding request; the completion loop decrements it.
req_count_++;
238
+ return h;
250
239
}
251
240
252
- void GRPCClient::AsyncCheckpointNotify (const std::string& ep,
253
- const std::string& dir,
254
- int64_t time_out) {
241
+ VarHandlePtr GRPCClient::AsyncCheckpointNotify (const std::string& ep,
242
+ const std::string& dir,
243
+ int64_t time_out) {
255
244
const auto ch = GetChannel (ep);
256
245
257
246
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor (ch);
258
- s->Prepare (time_out);
247
+ VarHandlePtr h (new VarHandle (ep, " CheckPointNotify" , CHECKPOINT_SAVE_MESSAGE,
248
+ nullptr , nullptr ));
249
+ s->Prepare (h, time_out);
259
250
260
251
sendrecv::VariableMessage req;
261
252
req.set_varname (CHECKPOINT_SAVE_MESSAGE);
@@ -264,6 +255,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
264
255
auto rpc = s->stub_ ->AsyncCheckpointNotify (s->context_ .get (), req, &cq_);
265
256
rpc->Finish (&s->reply_ , &s->status_ , reinterpret_cast <void *>(s));
266
257
req_count_++;
258
+ return h;
267
259
}
268
260
269
261
bool GRPCClient::Wait () {
@@ -276,32 +268,36 @@ void GRPCClient::Proceed() {
276
268
void * tag = nullptr ;
277
269
bool ok = false ;
278
270
271
+ VLOG (3 ) << " GRPCClient Proceed begin" ;
279
272
while (!stopped_ && cq_.Next (&tag, &ok)) {
280
273
BaseProcessor* c = static_cast <BaseProcessor*>(tag);
281
274
GPR_ASSERT (ok);
282
275
PADDLE_ENFORCE (c);
283
276
if (c->status_ .ok ()) {
284
- VLOG (3 ) << c->var_h_ . String () << " process" ;
277
+ VLOG (3 ) << c->GetVarHandlePtr ()-> String () << " process" ;
285
278
c->Process ();
286
279
} else if (c->status_ .error_code () == grpc::StatusCode::DEADLINE_EXCEEDED) {
287
- LOG (ERROR) << c->var_h_ . String ()
280
+ LOG (ERROR) << c->GetVarHandlePtr ()-> String ()
288
281
<< " meets grpc error:" << c->status_ .error_message ();
289
282
{
290
283
std::lock_guard<std::mutex> lk (sync_mutex_);
291
284
ok_ = false ;
292
285
}
293
- sync_cond_. notify_all ( );
286
+ c-> Finish ( false );
294
287
} else {
295
- LOG (FATAL) << c->var_h_ . String ()
288
+ LOG (FATAL) << c->GetVarHandlePtr ()-> String ()
296
289
<< " meets grpc error:" << c->status_ .error_message ();
290
+ c->Finish (false );
297
291
}
292
+
298
293
delete c;
299
294
{
300
295
std::lock_guard<std::mutex> lk (sync_mutex_);
301
296
req_count_--;
302
297
}
303
298
sync_cond_.notify_all ();
304
299
}
300
+ VLOG (3 ) << " GRPCClient Proceed end" ;
305
301
}
306
302
307
303
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel (const std::string& ep) {
0 commit comments