@@ -73,10 +73,11 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
73
73
const framework::Scope* p_scope = &scope;
74
74
const auto ch = GetChannel (ep_val);
75
75
SendProcessor* s = new SendProcessor (ch);
76
- VarHandlePtr h (new VarHandle (ep, " Send" , var_name_val, p_ctx, p_scope));
76
+ const std::string method = " SendRPC" ;
77
+ VarHandlePtr h (new VarHandle (ep, method, var_name_val, p_ctx, p_scope));
77
78
s->Prepare (h, time_out);
78
79
79
- framework::AsyncIO ([var_name_val, p_scope, p_ctx, s, this ] {
80
+ framework::AsyncIO ([var_name_val, p_scope, p_ctx, s, method, h, this ] {
80
81
auto * var = p_scope->FindVar (var_name_val);
81
82
82
83
::grpc::ByteBuffer req;
@@ -87,10 +88,16 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
87
88
// stub context
88
89
s->response_call_back_ = nullptr ;
89
90
91
+ platform::RecordEvent record_event (method, p_ctx);
92
+
90
93
auto call = s->stub_g_ .PrepareUnaryCall (
91
94
s->context_ .get (), " /sendrecv.SendRecvService/SendVariable" , req, &cq_);
92
95
call->StartCall ();
93
96
call->Finish (&s->reply_ , &s->status_ , reinterpret_cast <void *>(s));
97
+
98
+ if (UNLIKELY (platform::IsProfileEnabled ())) {
99
+ h->Wait ();
100
+ }
94
101
});
95
102
req_count_++;
96
103
@@ -122,10 +129,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
122
129
const framework::Scope* p_scope = &scope;
123
130
const auto ch = GetChannel (ep_val);
124
131
GetProcessor* s = new GetProcessor (ch);
125
- VarHandlePtr h (new VarHandle (ep, " Get" , var_name_val, p_ctx, p_scope));
132
+ const std::string method = " GetRPC" ;
133
+ VarHandlePtr h (new VarHandle (ep, method, var_name_val, p_ctx, p_scope));
126
134
s->Prepare (h, time_out);
127
135
128
- framework::AsyncIO ([var_name_val, s, this ] {
136
+ framework::AsyncIO ([var_name_val, s, method, p_ctx, h, this ] {
129
137
// prepare input
130
138
sendrecv::VariableMessage req;
131
139
req.set_varname (var_name_val);
@@ -137,10 +145,16 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
137
145
// stub context
138
146
s->response_call_back_ = ProcGetResponse;
139
147
148
+ platform::RecordEvent record_event (method, p_ctx);
149
+
140
150
auto call = s->stub_g_ .PrepareUnaryCall (
141
151
s->context_ .get (), " /sendrecv.SendRecvService/GetVariable" , buf, &cq_);
142
152
call->StartCall ();
143
153
call->Finish (&s->reply_ , &s->status_ , reinterpret_cast <void *>(s));
154
+
155
+ if (UNLIKELY (platform::IsProfileEnabled ())) {
156
+ h->Wait ();
157
+ }
144
158
});
145
159
146
160
req_count_++;
@@ -161,12 +175,14 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
161
175
const framework::Scope* p_scope = &scope;
162
176
const auto ch = GetChannel (ep_val);
163
177
GetProcessor* s = new GetProcessor (ch);
164
- VarHandlePtr h (
165
- new VarHandle (ep, " Prefetch" , out_var_name_val, p_ctx, p_scope));
178
+
179
+ const std::string method = " PrefetchRPC" ;
180
+
181
+ VarHandlePtr h (new VarHandle (ep, method, out_var_name_val, p_ctx, p_scope));
166
182
s->Prepare (h, time_out);
167
183
168
184
framework::AsyncIO ([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
169
- s, this ] {
185
+ s, method, h, this ] {
170
186
auto * var = p_scope->FindVar (in_var_name_val);
171
187
172
188
::grpc::ByteBuffer req;
@@ -177,11 +193,17 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
177
193
// stub context
178
194
s->response_call_back_ = ProcGetResponse;
179
195
196
+ platform::RecordEvent record_event (method, p_ctx);
197
+
180
198
auto call = s->stub_g_ .PrepareUnaryCall (
181
199
s->context_ .get (), " /sendrecv.SendRecvService/PrefetchVariable" , req,
182
200
&cq_);
183
201
call->StartCall ();
184
202
call->Finish (&s->reply_ , &s->status_ , static_cast <void *>(s));
203
+
204
+ if (UNLIKELY (platform::IsProfileEnabled ())) {
205
+ h->Wait ();
206
+ }
185
207
});
186
208
187
209
req_count_++;
@@ -193,31 +215,49 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
193
215
const auto ch = GetChannel (ep);
194
216
195
217
BatchBarrierProcessor* s = new BatchBarrierProcessor (ch);
196
- VarHandlePtr h (new VarHandle (ep, " BatchBarrier" , BATCH_BARRIER_MESSAGE,
197
- nullptr , nullptr ));
218
+ const std::string method = " BatchBarrierRPC" ;
219
+ VarHandlePtr h (
220
+ new VarHandle (ep, method, BATCH_BARRIER_MESSAGE, nullptr , nullptr ));
198
221
s->Prepare (h, time_out);
199
222
200
223
sendrecv::VariableMessage req;
201
224
req.set_varname (BATCH_BARRIER_MESSAGE);
225
+
226
+ platform::RecordEvent record_event (method, nullptr );
227
+
202
228
auto rpc = s->stub_ ->AsyncSendVariable (s->context_ .get (), req, &cq_);
203
229
rpc->Finish (&s->reply_ , &s->status_ , reinterpret_cast <void *>(s));
204
230
req_count_++;
231
+
232
+ if (UNLIKELY (platform::IsProfileEnabled ())) {
233
+ h->Wait ();
234
+ }
235
+
205
236
return h;
206
237
}
207
238
208
239
VarHandlePtr GRPCClient::AsyncSendFetchBarrier (const std::string& ep,
209
240
int64_t time_out) {
210
241
const auto ch = GetChannel (ep);
211
242
FetchBarrierProcessor* s = new FetchBarrierProcessor (ch);
212
- VarHandlePtr h (new VarHandle (ep, " FetchBarrier" , FETCH_BARRIER_MESSAGE,
213
- nullptr , nullptr ));
243
+ const std::string method = " FetchBarrierRPC" ;
244
+ VarHandlePtr h (
245
+ new VarHandle (ep, method, FETCH_BARRIER_MESSAGE, nullptr , nullptr ));
214
246
s->Prepare (h, time_out);
215
247
216
248
sendrecv::VariableMessage req;
217
249
req.set_varname (FETCH_BARRIER_MESSAGE);
250
+
251
+ platform::RecordEvent record_event (method, nullptr );
252
+
218
253
auto rpc = s->stub_ ->AsyncGetVariable (s->context_ .get (), req, &cq_);
219
254
rpc->Finish (&s->reply_ , &s->status_ , reinterpret_cast <void *>(s));
220
255
req_count_++;
256
+
257
+ if (UNLIKELY (platform::IsProfileEnabled ())) {
258
+ h->Wait ();
259
+ }
260
+
221
261
return h;
222
262
}
223
263
@@ -226,15 +266,23 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
226
266
const auto ch = GetChannel (ep);
227
267
228
268
BatchBarrierProcessor* s = new BatchBarrierProcessor (ch);
229
- VarHandlePtr h (
230
- new VarHandle (ep, " SendComplete " , COMPLETE_MESSAGE, nullptr , nullptr ));
269
+ const std::string method = " SendCompleteRPC " ;
270
+ VarHandlePtr h ( new VarHandle (ep, method , COMPLETE_MESSAGE, nullptr , nullptr ));
231
271
s->Prepare (h, time_out);
232
272
233
273
sendrecv::VariableMessage req;
234
274
req.set_varname (COMPLETE_MESSAGE);
275
+
276
+ platform::RecordEvent record_event (method, nullptr );
277
+
235
278
auto rpc = s->stub_ ->AsyncSendVariable (s->context_ .get (), req, &cq_);
236
279
rpc->Finish (&s->reply_ , &s->status_ , reinterpret_cast <void *>(s));
237
280
req_count_++;
281
+
282
+ if (UNLIKELY (platform::IsProfileEnabled ())) {
283
+ h->Wait ();
284
+ }
285
+
238
286
return h;
239
287
}
240
288
@@ -244,17 +292,27 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
244
292
const auto ch = GetChannel (ep);
245
293
246
294
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor (ch);
247
- VarHandlePtr h (new VarHandle (ep, " CheckPointNotify" , CHECKPOINT_SAVE_MESSAGE,
248
- nullptr , nullptr ));
295
+
296
+ const std::string method = " CheckPointNotifyRPC" ;
297
+
298
+ VarHandlePtr h (
299
+ new VarHandle (ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr , nullptr ));
249
300
s->Prepare (h, time_out);
250
301
251
302
sendrecv::VariableMessage req;
252
303
req.set_varname (CHECKPOINT_SAVE_MESSAGE);
253
304
req.set_out_varname (dir);
254
305
306
+ platform::RecordEvent record_event (method, nullptr );
307
+
255
308
auto rpc = s->stub_ ->AsyncCheckpointNotify (s->context_ .get (), req, &cq_);
256
309
rpc->Finish (&s->reply_ , &s->status_ , reinterpret_cast <void *>(s));
257
310
req_count_++;
311
+
312
+ if (UNLIKELY (platform::IsProfileEnabled ())) {
313
+ h->Wait ();
314
+ }
315
+
258
316
return h;
259
317
}
260
318
@@ -273,6 +331,7 @@ void GRPCClient::Proceed() {
273
331
BaseProcessor* c = static_cast <BaseProcessor*>(tag);
274
332
GPR_ASSERT (ok);
275
333
PADDLE_ENFORCE (c);
334
+
276
335
if (c->status_ .ok ()) {
277
336
VLOG (3 ) << c->GetVarHandlePtr ()->String () << " process" ;
278
337
c->Process ();
0 commit comments