@@ -62,7 +62,7 @@ VarHandlePtr BRPCClient::AsyncSendVar(const std::string& ep,
62
62
const std::string var_name_val = var_name;
63
63
const framework::Scope* p_scope = &scope;
64
64
const auto ch_ptr = GetChannel (ep_val);
65
- const std::string method = " SendRPC " ;
65
+ const std::string method = kSendRPC ;
66
66
VarHandlePtr var_h (new VarHandle (ep, method, var_name_val, p_ctx, p_scope));
67
67
68
68
framework::AsyncIO ([=] {
@@ -156,15 +156,18 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
156
156
const platform::DeviceContext& ctx,
157
157
const framework::Scope& scope,
158
158
const std::string& var_name,
159
+ const std::string& out_var_name,
159
160
const std::string& method_name,
160
161
int64_t time_out) {
161
162
const platform::DeviceContext* p_ctx = &ctx;
162
163
const std::string ep_val = ep;
163
164
const std::string var_name_val = var_name;
165
+ const std::string out_varname_val = out_var_name;
164
166
const framework::Scope* p_scope = &scope;
165
167
const auto ch_ptr = GetChannel (ep_val);
166
- const std::string method = " GetRPC" ;
167
- VarHandlePtr var_h (new VarHandle (ep, method, var_name_val, p_ctx, p_scope));
168
+ const std::string method = kGetRPC ;
169
+ VarHandlePtr var_h (
170
+ new VarHandle (ep, method, out_varname_val, p_ctx, p_scope));
168
171
169
172
framework::AsyncIO ([=] {
170
173
auto ch_ctx = ch_ptr->Pop ();
@@ -175,15 +178,18 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
175
178
176
179
sendrecv::VariableMessage req;
177
180
req.set_varname (var_name_val);
181
+ req.set_out_varname (out_varname_val);
178
182
req.set_trainer_id (trainer_id_);
179
183
180
184
google::protobuf::Closure* done = brpc::NewCallback (
181
185
&HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this );
182
186
183
187
platform::RecordRPCEvent record_event (method, p_ctx);
184
188
185
- if (method_name == " GetMonomerVariable " ) {
189
+ if (method_name == kGetMonomerRPC ) {
186
190
ch_ctx->stub ->GetMonomerVariable (cntl, &req, response, done);
191
+ } else if (method_name == kGetNoBarrierRPC ) {
192
+ ch_ctx->stub ->GetVariableNoBarrier (cntl, &req, response, done);
187
193
} else {
188
194
ch_ctx->stub ->GetVariable (cntl, &req, response, done);
189
195
}
@@ -198,25 +204,39 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
198
204
return var_h;
199
205
}
200
206
207
+ VarHandlePtr BRPCClient::AsyncGetVarNoBarrier (
208
+ const std::string& ep, const platform::DeviceContext& ctx,
209
+ const framework::Scope& scope, const std::string& var_name,
210
+ const std::string& out_var_name, int64_t time_out) {
211
+ std::string var_name_no_barrier =
212
+ string::Sprintf (" %s%s" , var_name, WITHOUT_BARRIER_MESSAGE);
213
+
214
+ return _AsyncGetVar (ep, ctx, scope, var_name_no_barrier, out_var_name,
215
+ kGetNoBarrierRPC , time_out);
216
+ }
217
+
201
218
VarHandlePtr BRPCClient::AsyncGetMonomerVariable (
202
219
const std::string& ep, const platform::DeviceContext& ctx,
203
220
const framework::Scope& scope, const std::string& var_name,
204
221
int64_t time_out) {
205
- return _AsyncGetVar (ep, ctx, scope, var_name, " GetMonomerVariable" , time_out);
222
+ return _AsyncGetVar (ep, ctx, scope, var_name, var_name, kGetMonomerRPC ,
223
+ time_out);
206
224
}
207
225
208
226
VarHandlePtr BRPCClient::AsyncGetMonomerBarrier (const std::string& ep,
209
227
const std::string& var_name,
210
228
int64_t time_out) {
211
- return AsyncSendMessage (ep, " GetMonomerBarrier " , var_name, time_out);
229
+ return AsyncSendMessage (ep, kSendMonomerFetchBarrierRPC , var_name, time_out);
212
230
}
213
231
214
232
VarHandlePtr BRPCClient::AsyncGetVar (const std::string& ep,
215
233
const platform::DeviceContext& ctx,
216
234
const framework::Scope& scope,
217
235
const std::string& var_name,
236
+ const std::string& out_var_name,
218
237
int64_t time_out) {
219
- return _AsyncGetVar (ep, ctx, scope, var_name, " GetVariable" , time_out);
238
+ return _AsyncGetVar (ep, ctx, scope, var_name, out_var_name, kGetRPC ,
239
+ time_out);
220
240
}
221
241
222
242
VarHandlePtr BRPCClient::AsyncPrefetchVar (const std::string& ep,
@@ -234,7 +254,7 @@ VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
234
254
const framework::Scope* p_scope = &scope;
235
255
const auto ch_ptr = GetChannel (ep_val);
236
256
237
- const std::string method = " PrefetchRPC " ;
257
+ const std::string method = kPrefetchRPC ;
238
258
239
259
VarHandlePtr var_h (
240
260
new VarHandle (ep, method, out_var_name_val, p_ctx, p_scope));
@@ -270,7 +290,7 @@ VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
270
290
271
291
VarHandlePtr BRPCClient::AsyncSendBatchBarrier (const std::string& ep,
272
292
int64_t time_out) {
273
- return AsyncSendMessage (ep, " BatchBarrierRPC " , BATCH_BARRIER_MESSAGE,
293
+ return AsyncSendMessage (ep, kBatchBarrierRPC , BATCH_BARRIER_MESSAGE,
274
294
time_out);
275
295
}
276
296
@@ -286,7 +306,7 @@ VarHandlePtr BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
286
306
sendrecv::VariableMessage req;
287
307
req.set_varname (FETCH_BARRIER_MESSAGE);
288
308
289
- const std::string method = " FetchBarrierRPC " ;
309
+ const std::string method = kFetchBarrierRPC ;
290
310
// var handle
291
311
VarHandlePtr var_h (
292
312
new VarHandle (ep, method, FETCH_BARRIER_MESSAGE, nullptr , nullptr ));
@@ -367,7 +387,7 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
367
387
368
388
VarHandlePtr BRPCClient::AsyncSendComplete (const std::string& ep,
369
389
int64_t time_out) {
370
- return AsyncSendMessage (ep, " SendCompleteRPC " , COMPLETE_MESSAGE, time_out);
390
+ return AsyncSendMessage (ep, kSendCompleteRPC , COMPLETE_MESSAGE, time_out);
371
391
}
372
392
373
393
void BRPCClient::SendComplete () {
@@ -394,9 +414,9 @@ VarHandlePtr BRPCClient::AsyncSendVarMessage(
394
414
google::protobuf::Closure* done = brpc::NewCallback (
395
415
&HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this );
396
416
397
- if (method_name == " CheckPointNotifyRPC " ) {
417
+ if (method_name == kCheckPointNotifyRPC ) {
398
418
ch_ctx->stub ->CheckpointNotify (cntl, &req, response, done);
399
- } else if (method_name == " GetMonomerBarrier " ) {
419
+ } else if (method_name == kSendMonomerFetchBarrierRPC ) {
400
420
ch_ctx->stub ->GetMonomerBarrier (cntl, &req, response, done);
401
421
} else {
402
422
ch_ctx->stub ->SendVariable (cntl, &req, response, done);
0 commit comments