@@ -89,16 +89,19 @@ void ListenAndServOp::SavePort() const {
89
89
rpc_service_->SavePort ();
90
90
}
91
91
92
- void ListenAndServOp::RunSyncLoop (framework::Executor *executor,
93
- framework::ProgramDesc *program,
94
- framework::Scope *recv_scope,
95
- framework::BlockDesc *prefetch_block) const {
92
+ void ListenAndServOp::RunSyncLoop (
93
+ framework::Executor *executor, framework::ProgramDesc *program,
94
+ framework::Scope *recv_scope,
95
+ const std::vector<int > &prefetch_block_id_list) const {
96
+ // FIXME(qiao): RunSyncLoop should not have to deal with the prefetch blocks.
97
+ // Currently the prefetch blocks can only be the last blocks of the program,
98
+ // so only blocks before prefetch_block_id_list[0] are run here.
96
99
size_t num_blocks = program->Size ();
97
100
PADDLE_ENFORCE_GE (num_blocks, 2 ,
98
101
" server program should have at least 2 blocks" );
99
102
100
103
std::vector<int > block_list;
101
- for (size_t blkid = 1 ; blkid < num_blocks ; ++blkid) {
104
+ for (size_t blkid = 1 ; blkid < prefetch_block_id_list[ 0 ] ; ++blkid) {
102
105
block_list.push_back (blkid);
103
106
}
104
107
auto optimize_prepared = executor->Prepare (*program, block_list);
@@ -128,16 +131,14 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
128
131
std::vector<size_t > parallel_blkids;
129
132
parallel_blkids.push_back (1 );
130
133
double ts = detail::GetTimestamp ();
131
- for (size_t blkid = 2 ; blkid < num_blocks; ++blkid) {
132
- if (blkid != static_cast <size_t >(prefetch_block->ID ())) {
133
- if (program->Block (blkid).Parent () != last_parent_blkid) {
134
- ParallelExecuteBlocks (parallel_blkids, executor, optimize_prepared,
135
- program, recv_scope);
136
- parallel_blkids.clear ();
137
- last_parent_blkid = program->Block (blkid).Parent ();
138
- }
139
- parallel_blkids.push_back (blkid);
134
+ for (size_t blkid = 2 ; blkid < prefetch_block_id_list[0 ]; ++blkid) {
135
+ if (program->Block (blkid).Parent () != last_parent_blkid) {
136
+ ParallelExecuteBlocks (parallel_blkids, executor, optimize_prepared,
137
+ program, recv_scope);
138
+ parallel_blkids.clear ();
139
+ last_parent_blkid = program->Block (blkid).Parent ();
140
140
}
141
+ parallel_blkids.push_back (blkid);
141
142
}
142
143
ParallelExecuteBlocks (parallel_blkids, executor, optimize_prepared, program,
143
144
recv_scope);
@@ -203,18 +204,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
203
204
} // while(true)
204
205
}
205
206
206
// FillRequestCtx: copies the server-side context into request handler `h` by
// calling its setters for scope, device context, executor, program, prefetch
// prepared context(s), and RPC server.
//
// This hunk changes the `prefetch_ctx` parameter. The removed ('-') version
// took a single raw framework::ExecutorPrepareContext* and wrapped it in a
// std::unique_ptr before handing it to the handler. The added ('+') version
// takes a pointer to an unordered_map from std::string to
// std::shared_ptr<framework::ExecutorPrepareContext> and forwards that pointer
// unchanged.
// NOTE(review): the visible caller in RunImpl keys this map by prefetch
// variable name and passes the address of a local map, so the handler
// presumably must not use it after RunImpl returns -- confirm.
- static void FillRequestCtx (detail::RequestHandler *h, framework::Scope *scope,
207
- platform::DeviceContext *dev_ctx,
208
- framework::Executor *executor,
209
- framework::ProgramDesc *program,
210
- framework::ExecutorPrepareContext *prefetch_ctx,
211
- detail::RPCServer *rpc_server) {
207
+ static void FillRequestCtx (
208
+ detail::RequestHandler *h, framework::Scope *scope,
209
+ platform::DeviceContext *dev_ctx, framework::Executor *executor,
210
+ framework::ProgramDesc *program,
211
+ std::unordered_map<std::string,
212
+ std::shared_ptr<framework::ExecutorPrepareContext>>
213
+ *prefetch_ctx,
214
+ detail::RPCServer *rpc_server) {
212
215
h->SetScope (scope);
213
216
h->SetDevCtx (dev_ctx);
214
217
h->SetExecutor (executor);
215
218
h->SetProgram (program);
216
// Removed: the single prepared context was wrapped in a unique_ptr here.
- h->SetPrefetchPreparedCtx (
217
- std::unique_ptr<framework::ExecutorPrepareContext>(prefetch_ctx));
219
// Added: the map pointer is passed through as-is (no ownership wrapper).
+ h->SetPrefetchPreparedCtx (prefetch_ctx);
218
220
h->SetRPCServer (rpc_server);
219
221
}
220
222
@@ -248,18 +250,41 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
248
250
request_prefetch_handler_.get ());
249
251
250
252
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock );
251
- auto grad_to_block_id_str = Attr<std::vector<std::string>>(kPrefetchBlock );
252
- framework::BlockDesc *prefetch_block = nullptr ;
253
253
auto *program = optimize_block->Program ();
254
254
framework::Executor executor (dev_place);
255
255
256
256
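// Prepare the executor context(s) for the prefetch block(s).
// NOTE(review): the added parsing code below logs pieces[0] and pieces[1]
// before it checks that pieces.size() == 2.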
257
- VLOG (3 ) << " prefetch block id is " << prefetch_block->ID ();
258
- auto prefetch_prepared = executor.Prepare (*program, prefetch_block->ID ());
257
+ std::vector<int > prefetch_block_id_list;
258
+ std::unordered_map<int32_t , std::string> block_id_to_prefetch_var_name;
259
+
260
+ auto prefetch_var_name_to_block_id_str =
261
+ Attr<std::vector<std::string>>(kPrefetchVarNameToBlockId );
262
+ for (const auto &prefetch_var_name_and_id :
263
+ prefetch_var_name_to_block_id_str) {
264
+ std::vector<std::string> pieces;
265
+ split (prefetch_var_name_and_id, ' :' , &pieces);
266
+ VLOG (3 ) << " after split, grad = " << pieces[0 ] << " , id=" << pieces[1 ];
267
+ PADDLE_ENFORCE_EQ (pieces.size (), 2 );
268
+
269
+ int block_id = std::stoi (pieces[1 ]);
270
+ prefetch_block_id_list.push_back (block_id);
271
+ block_id_to_prefetch_var_name[block_id] = pieces[0 ];
272
+ }
273
+
274
+ auto prefetch_prepared = executor.Prepare (*program, prefetch_block_id_list);
275
+
276
+ std::unordered_map<std::string,
277
+ std::shared_ptr<framework::ExecutorPrepareContext>>
278
+ prefetch_var_name_to_prepared_ctx;
279
+ for (int i = 0 ; i < prefetch_block_id_list.size (); ++i) {
280
+ auto block_id = prefetch_block_id_list[i];
281
+ auto prefetch_var_name = block_id_to_prefetch_var_name[block_id];
282
+ prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
283
+ }
259
284
260
285
auto f = std::bind (FillRequestCtx, std::placeholders::_1, &recv_scope,
261
- &dev_ctx, &executor, program, prefetch_prepared. release (),
262
- rpc_service_.get ());
286
+ &dev_ctx, &executor, program,
287
+ &prefetch_var_name_to_prepared_ctx, rpc_service_.get ());
263
288
264
289
f (request_send_handler_.get ());
265
290
f (request_get_handler_.get ());
@@ -277,7 +302,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
277
302
// Write the server-selected port to a file for Python to read.
278
303
SavePort ();
279
304
if (sync_mode) {
280
- RunSyncLoop (&executor, program, &recv_scope, prefetch_block );
305
+ RunSyncLoop (&executor, program, &recv_scope, prefetch_block_id_list );
281
306
} else {
282
307
RunAsyncLoop (&executor, program);
283
308
}
@@ -303,7 +328,7 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
303
328
AddAttr<bool >(" sync_mode" , " if works at sync_mode or not" ).SetDefault (true );
304
329
AddAttr<framework::BlockDesc *>(kOptimizeBlock ,
305
330
" BlockID to run on server side." );
306
- AddAttr<std::vector<std::string>>(kPrefetchBlock ,
331
+ AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId ,
307
332
" prefetch block to run on server side." );
308
333
AddAttr<int >(" Fanin" , " How many clients send to this server." )
309
334
.SetDefault (1 );
0 commit comments