@@ -45,20 +45,6 @@ static void split(const std::string &str, char sep,
45
45
}
46
46
}
47
47
48
// AsyncExecuteBlock: runs one prepared block on `scope` and blocks until it
// has finished.
//
// The work is handed to framework::Async as a lambda that calls
// executor->RunPreparedContext(prepared, scope, false, false). Any
// std::exception raised by that call is caught inside the lambda and only
// logged, so it does not reach the caller. The function then waits on the
// returned future before returning, which is why capturing the three pointer
// parameters by reference is safe here: they outlive the submitted task.
//
// NOTE(review): every line below carries a leading '-', i.e. this helper is on
// the removed side of this diff; its call site in RunAsyncLoop is removed as
// well and replaced by pushing the received message onto a per-gradient queue.
- static void AsyncExecuteBlock (framework::Executor *executor,
49
- framework::ExecutorPrepareContext *prepared,
50
- framework::Scope *scope) {
51
- std::future<void > future = framework::Async ([&executor, &prepared, &scope]() {
52
- try {
53
- executor->RunPreparedContext (prepared, scope, false , false );
54
- } catch (std::exception &e) {
55
- LOG (ERROR) << " run sub program error " << e.what ();
56
- }
57
- });
58
- // TODO(qiao) maybe we can remove this
59
- future.wait ();
60
- }
61
-
62
48
static void ParallelExecuteBlocks (
63
49
const std::vector<size_t > &parallel_blkids, framework::Executor *executor,
64
50
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
@@ -201,14 +187,40 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
201
187
} // while(true)
202
188
}
203
189
190
+ static void AsyncUpdateThread (
191
+ const std::string &var_name, const bool &exit_flag,
192
+ const std::shared_ptr<detail::ReceivedQueue> &queue,
193
+ framework::Executor *executor,
194
+ framework::ExecutorPrepareContext *prepared) {
195
+ VLOG (3 ) << " update thread for " << var_name << " started" ;
196
+ while (!exit_flag) {
197
+ const detail::ReceivedMessage v = queue->Pop ();
198
+ auto recv_var_name = v.first ;
199
+ auto var = v.second ->GetVar ();
200
+ if (var == nullptr ) {
201
+ LOG (ERROR) << " Can not find server side var: " << recv_var_name;
202
+ PADDLE_THROW (" Can not find server side var" );
203
+ }
204
+ auto fs = framework::Async ([var_name, &executor, &v, prepared] {
205
+ try {
206
+ executor->RunPreparedContext (prepared, v.second ->GetMutableLocalScope (),
207
+ false , false );
208
+ } catch (std::exception &e) {
209
+ LOG (ERROR) << " run sub program error " << e.what ();
210
+ }
211
+ });
212
+ fs.wait ();
213
+ }
214
+ }
215
+
204
216
void ListenAndServOp::RunAsyncLoop (framework::Executor *executor,
205
- framework::ProgramDesc *program,
206
- framework::Scope *recv_scope,
207
- framework::BlockDesc *prefetch_block) const {
217
+ framework::ProgramDesc *program) const {
208
218
VLOG (3 ) << " RunAsyncLoop in" ;
209
219
// grad name to block id
210
220
std::unordered_map<std::string, int32_t > grad_to_block_id;
211
221
std::unordered_map<int32_t , std::string> id_to_grad;
222
+ std::unordered_map<std::string, std::shared_ptr<detail::ReceivedQueue>>
223
+ grad_to_queue;
212
224
213
225
auto grad_to_block_id_str =
214
226
Attr<std::vector<std::string>>(" grad_to_block_id" );
@@ -220,6 +232,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
220
232
PADDLE_ENFORCE_EQ (grad_to_block_id.count (pieces[0 ]), 0 );
221
233
int block_id = std::stoi (pieces[1 ]);
222
234
grad_to_block_id[pieces[0 ]] = block_id;
235
+ grad_to_queue[pieces[0 ]] = std::make_shared<detail::ReceivedQueue>();
223
236
id_to_grad[block_id] = pieces[0 ];
224
237
}
225
238
size_t num_blocks = program->Size ();
@@ -238,8 +251,21 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
238
251
grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
239
252
}
240
253
241
- VLOG (3 ) << " RunAsyncLoop into while" ;
242
254
bool exit_flag = false ;
255
+
256
+ VLOG (3 ) << " start async optimize threads" ;
257
+ std::vector<std::future<void >> fs;
258
+ for (auto iter = grad_to_queue.begin (); iter != grad_to_queue.end (); iter++) {
259
+ std::string grad_name = iter->first ;
260
+ VLOG (3 ) << " create async update thread for " << grad_name;
261
+ fs.push_back (framework::AsyncIO ([grad_name, &exit_flag, &executor,
262
+ &grad_to_queue, &grad_to_prepared_ctx]() {
263
+ AsyncUpdateThread (grad_name, exit_flag, grad_to_queue[grad_name],
264
+ executor, grad_to_prepared_ctx[grad_name].get ());
265
+ }));
266
+ }
267
+
268
+ VLOG (3 ) << " RunAsyncLoop into while" ;
243
269
while (!exit_flag) {
244
270
const detail::ReceivedMessage v = rpc_service_->Get ();
245
271
auto recv_var_name = v.first ;
@@ -249,13 +275,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
249
275
break ;
250
276
} else {
251
277
VLOG (3 ) << " received grad: " << recv_var_name;
252
- auto var = v.second ->GetVar ();
253
- if (var == nullptr ) {
254
- LOG (ERROR) << " Can not find server side var: " << recv_var_name;
255
- PADDLE_THROW (" Can not find server side var" );
256
- }
257
- AsyncExecuteBlock (executor, grad_to_prepared_ctx[recv_var_name].get (),
258
- v.second ->GetMutableLocalScope ());
278
+ grad_to_queue[recv_var_name]->Push (v);
259
279
}
260
280
261
281
if (exit_flag) {
@@ -304,7 +324,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
304
324
if (sync_mode) {
305
325
RunSyncLoop (&executor, program, &recv_scope, prefetch_block);
306
326
} else {
307
- RunAsyncLoop (&executor, program, &recv_scope, prefetch_block );
327
+ RunAsyncLoop (&executor, program);
308
328
}
309
329
}
310
330
0 commit comments