Skip to content

Commit 32b94a7

Browse files
committed
cache var types
1 parent e5a9353 commit 32b94a7

File tree

2 files changed

+50
-17
lines changed

2 files changed

+50
-17
lines changed

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ void ListenAndServOp::RunSyncLoop(
104104
framework::Executor *executor, framework::ProgramDesc *program,
105105
framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
106106
const std::vector<int> &prefetch_block_id_list,
107-
const int checkpoint_point_block_id,
108-
const std::vector<std::string> &recv_varnames) const {
107+
const int checkpoint_point_block_id) const {
109108
VLOG(2) << "RunSyncLoop";
110109
size_t num_blocks = program->Size();
111110
auto optimize_blocks =
@@ -130,6 +129,7 @@ void ListenAndServOp::RunSyncLoop(
130129
rpc_service_->SetCond(distributed::kRequestGet);
131130
rpc_service_->WaitBarrier(distributed::kRequestGet);
132131
rpc_service_->ResetBarrierCounter();
132+
133133
while (true) {
134134
rpc_service_->Profiler().OneStep();
135135
// Get from multiple trainers, we don't care about the order in which
@@ -167,19 +167,18 @@ void ListenAndServOp::RunSyncLoop(
167167
recv_scope);
168168
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
169169

170-
ResetReceivedVars(recv_varnames, recv_scope, dev_ctx,
171-
rpc_service_->NeedResetAllVars());
170+
ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
172171

173172
rpc_service_->SetCond(distributed::kRequestGet);
174173
rpc_service_->WaitBarrier(distributed::kRequestGet);
175174
rpc_service_->ResetBarrierCounter();
176175
} // while(true)
177176
}
178177

179-
void ListenAndServOp::ResetReceivedVars(
180-
const std::vector<std::string> &recv_varnames, framework::Scope *recv_scope,
181-
platform::DeviceContext *dev_ctx, bool reset_all) const {
182-
for (auto &varname : recv_varnames) {
178+
void ListenAndServOp::ResetReceivedVars(framework::Scope *recv_scope,
179+
platform::DeviceContext *dev_ctx,
180+
bool reset_all) const {
181+
for (auto &varname : sparse_vars_) {
183182
auto var = recv_scope->FindVar(varname);
184183
if (var == nullptr) {
185184
VLOG(2) << "can not find var " << varname << " in received scope";
@@ -188,18 +187,25 @@ void ListenAndServOp::ResetReceivedVars(
188187
if (var->IsType<framework::SelectedRows>()) {
189188
VLOG(3) << "reset sparse var: " << varname;
190189
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
190+
} else {
191+
PADDLE_THROW("The type of sparse var should be SelectedRows");
191192
}
192-
if (UNLIKELY(reset_all)) {
193-
VLOG(3) << "reset dense var: " << varname;
193+
}
194+
if (UNLIKELY(reset_all)) {
195+
for (auto &varname : dense_vars_) {
196+
auto var = recv_scope->FindVar(varname);
197+
if (var == nullptr) {
198+
VLOG(2) << "can not find var " << varname << " in received scope";
199+
continue;
200+
}
194201
if (var->IsType<framework::LoDTensor>()) {
195202
math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
196203
static_cast<float>(0));
197204
} else if (var->IsType<framework::Tensor>()) {
198205
math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(),
199206
static_cast<float>(0));
200207
} else {
201-
PADDLE_THROW(
202-
"received var should be in [SelectedRows, LoDTensor, Tensor]");
208+
PADDLE_THROW("The type of dense var should be in [LoDTensor, Tensor]");
203209
}
204210
}
205211
}
@@ -278,6 +284,25 @@ static void FillRequestCtx(
278284
h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
279285
}
280286

287+
void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames,
288+
const framework::Scope &scope) const {
289+
for (const auto &varname : varnames) {
290+
auto var = scope.FindVar(varname);
291+
PADDLE_ENFORCE(var != nullptr,
292+
"Received var should be initialized in the received scope.");
293+
if (var->IsType<framework::SelectedRows>()) {
294+
sparse_vars_.push_back(varname);
295+
} else if (var->IsType<framework::LoDTensor>() ||
296+
var->IsType<framework::Tensor>()) {
297+
dense_vars_.push_back(varname);
298+
} else {
299+
PADDLE_THROW(
300+
"The type of received var should be in [SelectedRows, LoDTensor, "
301+
"Tensor].");
302+
}
303+
}
304+
}
305+
281306
void ListenAndServOp::RunImpl(const framework::Scope &scope,
282307
const platform::Place &dev_place) const {
283308
// Mark this as PS that it should decide profiling by listening from trainer.
@@ -379,11 +404,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
379404
signal(SIGINT, SignalHandler::StopAndExit);
380405
signal(SIGTERM, SignalHandler::StopAndExit);
381406

407+
// Cache the type of the received vars as `sparse_vars_` and `dense_vars_`
408+
// so that we can reset them at the end of each iteration.
409+
// NOTE: only used in sync update
410+
CacheVarsType(inputs, recv_scope);
411+
382412
// Write to a file of server selected port for python use.
383413
SavePort();
384414
if (sync_mode) {
385415
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
386-
prefetch_block_id_list, checkpoint_block_id, inputs);
416+
prefetch_block_id_list, checkpoint_block_id);
387417
} else {
388418
RunAsyncLoop(&executor, program, &recv_scope);
389419
}

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ class ListenAndServOp : public framework::OperatorBase {
5151
framework::Scope* recv_scope,
5252
platform::DeviceContext* dev_ctx,
5353
const std::vector<int>& prefetch_block_id_list,
54-
const int checkpoint_point_block_id,
55-
const std::vector<std::string>& recv_varnames) const;
54+
const int checkpoint_point_block_id) const;
5655

5756
void RunAsyncLoop(framework::Executor* executor,
5857
framework::ProgramDesc* program,
@@ -67,11 +66,13 @@ class ListenAndServOp : public framework::OperatorBase {
6766
void RunImpl(const framework::Scope& scope,
6867
const platform::Place& dev_place) const override;
6968

70-
void ResetReceivedVars(const std::vector<std::string>& recv_varnames,
71-
framework::Scope* recv_scope,
69+
void ResetReceivedVars(framework::Scope* recv_scope,
7270
platform::DeviceContext* dev_ctx,
7371
bool reset_all = false) const;
7472

73+
void CacheVarsType(const std::vector<std::string>& varnames,
74+
const framework::Scope& scope) const;
75+
7576
protected:
7677
mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
7778
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
@@ -82,6 +83,8 @@ class ListenAndServOp : public framework::OperatorBase {
8283
request_checkpoint_handler_;
8384

8485
mutable std::shared_ptr<std::thread> server_thread_;
86+
mutable std::vector<std::string> sparse_vars_;
87+
mutable std::vector<std::string> dense_vars_;
8588
};
8689

8790
class SignalHandler {

0 commit comments

Comments
 (0)