@@ -207,18 +207,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
207
207
framework::BlockDesc *prefetch_block) const {
208
208
VLOG (3 ) << " RunAsyncLoop in" ;
209
209
// grad name to block id
210
- std::unordered_map<std::string, int32_t > grad_to_id ;
210
+ std::unordered_map<std::string, int32_t > grad_to_block_id ;
211
211
std::unordered_map<int32_t , std::string> id_to_grad;
212
212
213
- auto grad_to_id_str = Attr<std::vector<std::string>>(" grad_to_id" );
214
- for (auto &grad_and_id : grad_to_id_str) {
213
+ auto grad_to_block_id_str =
214
+ Attr<std::vector<std::string>>(" grad_to_block_id" );
215
+ for (auto &grad_and_id : grad_to_block_id_str) {
215
216
std::vector<std::string> pieces;
216
217
split (grad_and_id, ' :' , &pieces);
217
218
VLOG (3 ) << " after split, grad = " << pieces[0 ] << " , id=" << pieces[1 ];
218
219
PADDLE_ENFORCE_EQ (pieces.size (), 2 );
219
- PADDLE_ENFORCE_EQ (grad_to_id .count (pieces[0 ]), 0 );
220
+ PADDLE_ENFORCE_EQ (grad_to_block_id .count (pieces[0 ]), 0 );
220
221
int block_id = std::stoi (pieces[1 ]);
221
- grad_to_id [pieces[0 ]] = block_id;
222
+ grad_to_block_id [pieces[0 ]] = block_id;
222
223
id_to_grad[block_id] = pieces[0 ];
223
224
}
224
225
size_t num_blocks = program->Size ();
@@ -232,9 +233,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
232
233
auto optimize_prepared = executor->Prepare (*program, block_list);
233
234
std::unordered_map<std::string,
234
235
std::shared_ptr<framework::ExecutorPrepareContext>>
235
- grad_to_prepared ;
236
+ grad_to_prepared_block ;
236
237
for (size_t i = 0 ; i < block_list.size (); ++i) {
237
- grad_to_prepared [id_to_grad[block_list[i]]] = optimize_prepared[i];
238
+ grad_to_prepared_block [id_to_grad[block_list[i]]] = optimize_prepared[i];
238
239
}
239
240
240
241
VLOG (3 ) << " RunAsyncLoop into while" ;
@@ -253,8 +254,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
253
254
LOG (ERROR) << " Can not find server side var: " << recv_var_name;
254
255
PADDLE_THROW (" Can not find server side var" );
255
256
}
256
- AsyncExecuteBlock (executor, grad_to_prepared [recv_var_name].get (),
257
- &( v.second ->GetLocalScope () ));
257
+ AsyncExecuteBlock (executor, grad_to_prepared_block [recv_var_name].get (),
258
+ v.second ->GetMutableLocalScope ( ));
258
259
// TODO(qiao): explain why
259
260
if (var->IsType <framework::SelectedRows>()) {
260
261
var->GetMutable <framework::SelectedRows>()->mutable_rows ()->clear ();
@@ -328,7 +329,7 @@ from send_op and send back variables to recv_op.
328
329
.SetDefault (" 127.0.0.1:6164" )
329
330
.AddCustomChecker ([](const std::string &ip) { return !ip.empty (); });
330
331
AddAttr<std::vector<std::string>>(
331
- " grad_to_id " ,
332
+ " grad_to_block_id " ,
332
333
333
334
" a map from grad name to it's optimize block id" )
334
335
.SetDefault ({});
0 commit comments