@@ -96,19 +96,22 @@ static int64_t GetTimestamp() {
96
96
return tp.tv_sec * 1000 + tp.tv_usec / 1000 ;
97
97
}
98
98
99
- void ListenAndServOp::RunSyncLoop (framework::Executor *executor,
100
- framework::ProgramDesc *program,
101
- framework::Scope *recv_scope,
102
- framework::BlockDesc *prefetch_block ) const {
99
+ void ListenAndServOp::RunSyncLoop (
100
+ framework::Executor *executor, framework::ProgramDesc *program,
101
+ framework::Scope *recv_scope,
102
+ const std::vector< int > &prefetch_block_id_list ) const {
103
103
size_t num_blocks = program->Size ();
104
104
PADDLE_ENFORCE_GE (num_blocks, 2 ,
105
105
" server program should have at least 2 blocks" );
106
106
107
- std::vector<int > block_list;
108
- for (size_t blkid = 1 ; blkid < num_blocks; ++blkid) {
109
- block_list.push_back (blkid);
107
+ std::vector<int > optimize_block_id_list;
108
+ for (int blkid = 1 ; blkid < num_blocks; ++blkid) {
109
+ if (std::find (prefetch_block_id_list.begin (), prefetch_block_id_list.end (),
110
+ blkid) == prefetch_block_id_list.end ()) {
111
+ optimize_block_id_list.push_back (blkid);
112
+ }
110
113
}
111
- auto optimize_prepared = executor->Prepare (*program, block_list );
114
+ auto optimize_prepared = executor->Prepare (*program, optimize_block_id_list );
112
115
// Insert placeholder for block0 which holds current op itself.
113
116
optimize_prepared.insert (
114
117
optimize_prepared.begin (),
@@ -135,16 +138,17 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
135
138
std::vector<size_t > parallel_blkids;
136
139
parallel_blkids.push_back (1 );
137
140
double ts = GetTimestamp ();
138
- for (size_t blkid = 2 ; blkid < num_blocks ; ++blkid ) {
139
- if (blkid != static_cast < size_t >(prefetch_block-> ID ())) {
140
- if (program-> Block (blkid). Parent () != last_parent_blkid) {
141
- ParallelExecuteBlocks (parallel_blkids, executor, optimize_prepared,
142
- program, recv_scope);
143
- parallel_blkids. clear ();
144
- last_parent_blkid = program-> Block (blkid). Parent ( );
145
- }
146
- parallel_blkids. push_back (blkid);
141
+ for (size_t i = 1 ; i < optimize_block_id_list. size () ; ++i ) {
142
+ // skip the first optimize block because it is already in the
143
+ // parallel_blkids.
144
+ int blkid = optimize_block_id_list[i];
145
+ if ( program-> Block (blkid). Parent () != last_parent_blkid) {
146
+ ParallelExecuteBlocks (parallel_blkids, executor, optimize_prepared,
147
+ program, recv_scope );
148
+ parallel_blkids. clear ();
149
+ last_parent_blkid = program-> Block (blkid). Parent ( );
147
150
}
151
+ parallel_blkids.push_back (blkid);
148
152
}
149
153
ParallelExecuteBlocks (parallel_blkids, executor, optimize_prepared, program,
150
154
recv_scope);
@@ -210,18 +214,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
210
214
} // while(true)
211
215
}
212
216
213
- static void FillRequestCtx (detail::RequestHandler *h, framework::Scope *scope,
214
- platform::DeviceContext *dev_ctx,
215
- framework::Executor *executor,
216
- framework::ProgramDesc *program,
217
- framework::ExecutorPrepareContext *prefetch_ctx,
218
- detail::RPCServer *rpc_server) {
217
+ static void FillRequestCtx (
218
+ detail::RequestHandler *h, framework::Scope *scope,
219
+ platform::DeviceContext *dev_ctx, framework::Executor *executor,
220
+ framework::ProgramDesc *program,
221
+ std::unordered_map<std::string,
222
+ std::shared_ptr<framework::ExecutorPrepareContext>>
223
+ *prefetch_ctx,
224
+ detail::RPCServer *rpc_server) {
219
225
h->SetScope (scope);
220
226
h->SetDevCtx (dev_ctx);
221
227
h->SetExecutor (executor);
222
228
h->SetProgram (program);
223
- h->SetPrefetchPreparedCtx (
224
- std::unique_ptr<framework::ExecutorPrepareContext>(prefetch_ctx));
229
+ h->SetPrefetchPreparedCtx (prefetch_ctx);
225
230
h->SetRPCServer (rpc_server);
226
231
}
227
232
@@ -255,17 +260,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
255
260
request_prefetch_handler_.get ());
256
261
257
262
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock );
258
- auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock );
259
263
auto *program = optimize_block->Program ();
260
264
framework::Executor executor (dev_place);
261
265
262
266
// prepare for prefetch
263
- VLOG (3 ) << " prefetch block id is " << prefetch_block->ID ();
264
- auto prefetch_prepared = executor.Prepare (*program, prefetch_block->ID ());
267
+ std::vector<int > prefetch_block_id_list;
268
+ std::unordered_map<int , std::string> block_id_to_prefetch_var_name;
269
+
270
+ auto prefetch_var_name_to_block_id_str =
271
+ Attr<std::vector<std::string>>(kPrefetchVarNameToBlockId );
272
+ for (const auto &prefetch_var_name_and_id :
273
+ prefetch_var_name_to_block_id_str) {
274
+ std::vector<std::string> pieces;
275
+ split (prefetch_var_name_and_id, ' :' , &pieces);
276
+ VLOG (3 ) << " after split, prefetch_var = " << pieces[0 ]
277
+ << " , id=" << pieces[1 ];
278
+ PADDLE_ENFORCE_EQ (pieces.size (), 2 );
279
+
280
+ int block_id = std::stoi (pieces[1 ]);
281
+ prefetch_block_id_list.push_back (block_id);
282
+ block_id_to_prefetch_var_name[block_id] = pieces[0 ];
283
+ }
284
+
285
+ auto prefetch_prepared = executor.Prepare (*program, prefetch_block_id_list);
286
+
287
+ std::unordered_map<std::string,
288
+ std::shared_ptr<framework::ExecutorPrepareContext>>
289
+ prefetch_var_name_to_prepared_ctx;
290
+ for (size_t i = 0 ; i < prefetch_block_id_list.size (); ++i) {
291
+ auto block_id = prefetch_block_id_list[i];
292
+ auto prefetch_var_name = block_id_to_prefetch_var_name[block_id];
293
+ prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
294
+ }
265
295
266
296
auto f = std::bind (FillRequestCtx, std::placeholders::_1, &recv_scope,
267
- &dev_ctx, &executor, program, prefetch_prepared. release (),
268
- rpc_service_.get ());
297
+ &dev_ctx, &executor, program,
298
+ &prefetch_var_name_to_prepared_ctx, rpc_service_.get ());
269
299
270
300
f (request_send_handler_.get ());
271
301
f (request_get_handler_.get ());
@@ -283,7 +313,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
283
313
// Write to a file of server selected port for python use.
284
314
SavePort ();
285
315
if (sync_mode) {
286
- RunSyncLoop (&executor, program, &recv_scope, prefetch_block );
316
+ RunSyncLoop (&executor, program, &recv_scope, prefetch_block_id_list );
287
317
} else {
288
318
RunAsyncLoop (&executor, program);
289
319
}
@@ -309,8 +339,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
309
339
AddAttr<bool >(" sync_mode" , " if works at sync_mode or not" ).SetDefault (true );
310
340
AddAttr<framework::BlockDesc *>(kOptimizeBlock ,
311
341
" BlockID to run on server side." );
312
- AddAttr<framework::BlockDesc *>(kPrefetchBlock ,
313
- " prefetch block to run on server side." );
342
+ AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId ,
343
+ " prefetch blocks to run on server side." )
344
+ .SetDefault ({});
314
345
AddAttr<int >(" Fanin" , " How many clients send to this server." )
315
346
.SetDefault (1 );
316
347
}
0 commit comments