@@ -104,8 +104,7 @@ void ListenAndServOp::RunSyncLoop(
104
104
framework::Executor *executor, framework::ProgramDesc *program,
105
105
framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
106
106
const std::vector<int > &prefetch_block_id_list,
107
- const int checkpoint_point_block_id,
108
- const std::vector<std::string> &recv_varnames) const {
107
+ const int checkpoint_point_block_id) const {
109
108
VLOG (2 ) << " RunSyncLoop" ;
110
109
size_t num_blocks = program->Size ();
111
110
auto optimize_blocks =
@@ -130,6 +129,7 @@ void ListenAndServOp::RunSyncLoop(
130
129
rpc_service_->SetCond (distributed::kRequestGet );
131
130
rpc_service_->WaitBarrier (distributed::kRequestGet );
132
131
rpc_service_->ResetBarrierCounter ();
132
+
133
133
while (true ) {
134
134
rpc_service_->Profiler ().OneStep ();
135
135
// Get from multiple trainers, we don't care about the order in which
@@ -167,19 +167,18 @@ void ListenAndServOp::RunSyncLoop(
167
167
recv_scope);
168
168
VLOG (2 ) << " run all blocks spent " << GetTimestamp () - ts << " (ms)" ;
169
169
170
- ResetReceivedVars (recv_varnames, recv_scope, dev_ctx,
171
- rpc_service_->NeedResetAllVars ());
170
+ ResetReceivedVars (recv_scope, dev_ctx, rpc_service_->NeedResetAllVars ());
172
171
173
172
rpc_service_->SetCond (distributed::kRequestGet );
174
173
rpc_service_->WaitBarrier (distributed::kRequestGet );
175
174
rpc_service_->ResetBarrierCounter ();
176
175
} // while(true)
177
176
}
178
177
179
- void ListenAndServOp::ResetReceivedVars (
180
- const std::vector<std::string> &recv_varnames, framework::Scope *recv_scope ,
181
- platform::DeviceContext *dev_ctx, bool reset_all) const {
182
- for (auto &varname : recv_varnames ) {
178
+ void ListenAndServOp::ResetReceivedVars (framework::Scope *recv_scope,
179
+ platform::DeviceContext *dev_ctx ,
180
+ bool reset_all) const {
181
+ for (auto &varname : sparse_vars_ ) {
183
182
auto var = recv_scope->FindVar (varname);
184
183
if (var == nullptr ) {
185
184
VLOG (2 ) << " can not find var " << varname << " in received scope" ;
@@ -188,18 +187,25 @@ void ListenAndServOp::ResetReceivedVars(
188
187
if (var->IsType <framework::SelectedRows>()) {
189
188
VLOG (3 ) << " reset sparse var: " << varname;
190
189
var->GetMutable <framework::SelectedRows>()->mutable_rows ()->clear ();
190
+ } else {
191
+ PADDLE_THROW (" The type of sparse var should be SelectedRows" );
191
192
}
192
- if (UNLIKELY (reset_all)) {
193
- VLOG (3 ) << " reset dense var: " << varname;
193
+ }
194
+ if (UNLIKELY (reset_all)) {
195
+ for (auto &varname : dense_vars_) {
196
+ auto var = recv_scope->FindVar (varname);
197
+ if (var == nullptr ) {
198
+ VLOG (2 ) << " can not find var " << varname << " in received scope" ;
199
+ continue ;
200
+ }
194
201
if (var->IsType <framework::LoDTensor>()) {
195
202
math::set_constant (*dev_ctx, var->GetMutable <framework::LoDTensor>(),
196
203
static_cast <float >(0 ));
197
204
} else if (var->IsType <framework::Tensor>()) {
198
205
math::set_constant (*dev_ctx, var->GetMutable <framework::Tensor>(),
199
206
static_cast <float >(0 ));
200
207
} else {
201
- PADDLE_THROW (
202
- " received var should be in [SelectedRows, LoDTensor, Tensor]" );
208
+ PADDLE_THROW (" The type of dense var should be in [LoDTensor, Tensor]" );
203
209
}
204
210
}
205
211
}
@@ -278,6 +284,25 @@ static void FillRequestCtx(
278
284
h->SetCheckpointNotifyPreparedCtx (checkpoint_ctx);
279
285
}
280
286
287
+ void ListenAndServOp::CacheVarsType (const std::vector<std::string> &varnames,
288
+ const framework::Scope &scope) const {
289
+ for (const auto &varname : varnames) {
290
+ auto var = scope.FindVar (varname);
291
+ PADDLE_ENFORCE (var != nullptr ,
292
+ " Received var should be initialized in the received scope." );
293
+ if (var->IsType <framework::SelectedRows>()) {
294
+ sparse_vars_.push_back (varname);
295
+ } else if (var->IsType <framework::LoDTensor>() ||
296
+ var->IsType <framework::Tensor>()) {
297
+ dense_vars_.push_back (varname);
298
+ } else {
299
+ PADDLE_THROW (
300
+ " The type of received var should be in [SelectedRows, LoDTensor, "
301
+ " Tensor]." );
302
+ }
303
+ }
304
+ }
305
+
281
306
void ListenAndServOp::RunImpl (const framework::Scope &scope,
282
307
const platform::Place &dev_place) const {
283
308
// Mark this as PS that it should decide profiling by listening from trainer.
@@ -379,11 +404,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
379
404
signal (SIGINT, SignalHandler::StopAndExit);
380
405
signal (SIGTERM, SignalHandler::StopAndExit);
381
406
407
+ // Cache the type of the received vars as `sparse_vars_` and `dense_vars_`
408
+ // so that we can reset them at the end of each iteration.
409
+ // NOTE: only used in sync update
410
+ CacheVarsType (inputs, recv_scope);
411
+
382
412
// Write to a file of server selected port for python use.
383
413
SavePort ();
384
414
if (sync_mode) {
385
415
RunSyncLoop (&executor, program, &recv_scope, &dev_ctx,
386
- prefetch_block_id_list, checkpoint_block_id, inputs );
416
+ prefetch_block_id_list, checkpoint_block_id);
387
417
} else {
388
418
RunAsyncLoop (&executor, program, &recv_scope);
389
419
}
0 commit comments