@@ -22,6 +22,7 @@ limitations under the License. */
22
22
#include " gflags/gflags.h"
23
23
24
24
#include " paddle/fluid/operators/detail/macros.h"
25
+ #include " paddle/fluid/operators/math/math_function.h"
25
26
26
27
#include " paddle/fluid/operators/distributed/request_handler_impl.h"
27
28
#include " paddle/fluid/operators/listen_and_serv_op.h"
@@ -101,7 +102,7 @@ static int64_t GetTimestamp() {
101
102
102
103
void ListenAndServOp::RunSyncLoop (
103
104
framework::Executor *executor, framework::ProgramDesc *program,
104
- framework::Scope *recv_scope,
105
+ framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
105
106
const std::vector<int > &prefetch_block_id_list,
106
107
const int checkpoint_point_block_id) const {
107
108
VLOG (2 ) << " RunSyncLoop" ;
@@ -128,6 +129,7 @@ void ListenAndServOp::RunSyncLoop(
128
129
rpc_service_->SetCond (distributed::kRequestGet );
129
130
rpc_service_->WaitBarrier (distributed::kRequestGet );
130
131
rpc_service_->ResetBarrierCounter ();
132
+
131
133
while (true ) {
132
134
rpc_service_->Profiler ().OneStep ();
133
135
// Get from multiple trainers, we don't care about the order in which
@@ -165,16 +167,50 @@ void ListenAndServOp::RunSyncLoop(
165
167
recv_scope);
166
168
VLOG (2 ) << " run all blocks spent " << GetTimestamp () - ts << " (ms)" ;
167
169
168
- // reset received sparse vars to avoid reuse it in the next mini-batch
169
- dynamic_cast <distributed::RequestSendHandler *>(request_send_handler_.get ())
170
- ->ResetSparseVarRecorder ();
170
+ ResetReceivedVars (recv_scope, dev_ctx, rpc_service_->NeedResetAllVars ());
171
171
172
172
rpc_service_->SetCond (distributed::kRequestGet );
173
173
rpc_service_->WaitBarrier (distributed::kRequestGet );
174
174
rpc_service_->ResetBarrierCounter ();
175
175
} // while(true)
176
176
}
177
177
178
+ void ListenAndServOp::ResetReceivedVars (framework::Scope *recv_scope,
179
+ platform::DeviceContext *dev_ctx,
180
+ bool reset_all) const {
181
+ for (auto &varname : sparse_vars_) {
182
+ auto var = recv_scope->FindVar (varname);
183
+ if (var == nullptr ) {
184
+ VLOG (2 ) << " can not find var " << varname << " in received scope" ;
185
+ continue ;
186
+ }
187
+ if (var->IsType <framework::SelectedRows>()) {
188
+ VLOG (3 ) << " reset sparse var: " << varname;
189
+ var->GetMutable <framework::SelectedRows>()->mutable_rows ()->clear ();
190
+ } else {
191
+ PADDLE_THROW (" The type of sparse var should be SelectedRows" );
192
+ }
193
+ }
194
+ if (UNLIKELY (reset_all)) {
195
+ for (auto &varname : dense_vars_) {
196
+ auto var = recv_scope->FindVar (varname);
197
+ if (var == nullptr ) {
198
+ VLOG (2 ) << " can not find var " << varname << " in received scope" ;
199
+ continue ;
200
+ }
201
+ if (var->IsType <framework::LoDTensor>()) {
202
+ math::set_constant (*dev_ctx, var->GetMutable <framework::LoDTensor>(),
203
+ static_cast <float >(0 ));
204
+ } else if (var->IsType <framework::Tensor>()) {
205
+ math::set_constant (*dev_ctx, var->GetMutable <framework::Tensor>(),
206
+ static_cast <float >(0 ));
207
+ } else {
208
+ PADDLE_THROW (" The type of dense var should be in [LoDTensor, Tensor]" );
209
+ }
210
+ }
211
+ }
212
+ }
213
+
178
214
void ListenAndServOp::RunAsyncLoop (framework::Executor *executor,
179
215
framework::ProgramDesc *program,
180
216
framework::Scope *recv_scope) const {
@@ -248,6 +284,25 @@ static void FillRequestCtx(
248
284
h->SetCheckpointNotifyPreparedCtx (checkpoint_ctx);
249
285
}
250
286
287
+ void ListenAndServOp::CacheVarsType (const std::vector<std::string> &varnames,
288
+ const framework::Scope &scope) const {
289
+ for (const auto &varname : varnames) {
290
+ auto var = scope.FindVar (varname);
291
+ PADDLE_ENFORCE (var != nullptr ,
292
+ " Received var should be initialized in the received scope." );
293
+ if (var->IsType <framework::SelectedRows>()) {
294
+ sparse_vars_.push_back (varname);
295
+ } else if (var->IsType <framework::LoDTensor>() ||
296
+ var->IsType <framework::Tensor>()) {
297
+ dense_vars_.push_back (varname);
298
+ } else {
299
+ PADDLE_THROW (
300
+ " The type of received var should be in [SelectedRows, LoDTensor, "
301
+ " Tensor]." );
302
+ }
303
+ }
304
+ }
305
+
251
306
void ListenAndServOp::RunImpl (const framework::Scope &scope,
252
307
const platform::Place &dev_place) const {
253
308
// Mark this as PS that it should decide profiling by listening from trainer.
@@ -258,6 +313,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
258
313
259
314
bool sync_mode = Attr<bool >(" sync_mode" );
260
315
auto fan_in = Attr<int >(" Fanin" );
316
+ auto inputs = Inputs (" X" );
261
317
262
318
PADDLE_ENFORCE (!rpc_service_);
263
319
std::string endpoint = Attr<std::string>(" endpoint" );
@@ -348,11 +404,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
348
404
signal (SIGINT, SignalHandler::StopAndExit);
349
405
signal (SIGTERM, SignalHandler::StopAndExit);
350
406
407
+ // Cache the type of the received vars as `sparse_vars_` and `dense_vars_`
408
+ // so that we can reset them at the end of each iteration.
409
+ // NOTE: only used in sync update
410
+ CacheVarsType (inputs, recv_scope);
411
+
351
412
// Write to a file of server selected port for python use.
352
413
SavePort ();
353
414
if (sync_mode) {
354
- RunSyncLoop (&executor, program, &recv_scope, prefetch_block_id_list ,
355
- checkpoint_block_id);
415
+ RunSyncLoop (&executor, program, &recv_scope, &dev_ctx ,
416
+ prefetch_block_id_list, checkpoint_block_id);
356
417
} else {
357
418
RunAsyncLoop (&executor, program, &recv_scope);
358
419
}
0 commit comments