Skip to content

Commit d117bbc

Browse files
author
Yan Xu
authored
Merge pull request #13291 from Yancey1989/reset_vars_on_pserver
reset received vars on pserver
2 parents 2fd1bf2 + 5558784 commit d117bbc

File tree

6 files changed

+91
-25
lines changed

6 files changed

+91
-25
lines changed

paddle/fluid/operators/distributed/request_handler_impl.cc

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,24 +67,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
6767
LOG(FATAL) << "sync: Can not find server side var: " << varname;
6868
return false;
6969
}
70-
71-
if (invar->IsType<framework::SelectedRows>()) {
72-
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
73-
sparse_vars_.push_back(invar);
74-
}
7570
}
7671
}
7772
return true;
7873
}
7974

80-
// Removed by this commit (every line below carries a '-' gutter marker).
// While holding mutex_sparse_vars_, clears the row list of every variable
// recorded in sparse_vars_ and then empties the recorder itself.
void RequestSendHandler::ResetSparseVarRecorder() {
81-
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
82-
for (auto* var : sparse_vars_) {
83-
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
84-
}
85-
sparse_vars_.clear();
86-
}
87-
8875
bool RequestGetHandler::Handle(const std::string& varname,
8976
framework::Scope* scope,
9077
framework::Variable* invar,

paddle/fluid/operators/distributed/request_handler_impl.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,6 @@ class RequestSendHandler final : public RequestHandler {
4141
bool Handle(const std::string& varname, framework::Scope* scope,
4242
framework::Variable* var, framework::Variable** outvar,
4343
const std::string& out_var_name = "") override;
44-
void ResetSparseVarRecorder();
45-
46-
private:
47-
std::mutex mutex_sparse_vars_;
48-
std::vector<framework::Variable*> sparse_vars_;
4944
};
5045

5146
class RequestGetHandler final : public RequestHandler {

paddle/fluid/operators/distributed/rpc_server.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ void RPCServer::Complete() {
101101
{
102102
std::unique_lock<std::mutex> lock(mutex_);
103103
client_num_--;
104+
need_reset_all_vars_ = true;
105+
104106
VLOG(4) << "decrease client_num to: " << client_num_;
105107
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
106108
barrier_counter_[kRequestGet]--;
@@ -109,6 +111,11 @@ void RPCServer::Complete() {
109111
barrier_cond_.notify_all();
110112
}
111113

114+
// Returns the current value of need_reset_all_vars_, read while holding
// mutex_. In this diff the flag is set to true in Complete() and set back to
// false in ResetBarrierCounter().
bool RPCServer::NeedResetAllVars() {
115+
std::unique_lock<std::mutex> lock(mutex_);
116+
return need_reset_all_vars_;
117+
}
118+
112119
int RPCServer::GetClientNum() {
113120
std::unique_lock<std::mutex> lock(mutex_);
114121
return client_num_;
@@ -120,6 +127,7 @@ void RPCServer::ResetBarrierCounter() {
120127
for (auto& t : barrier_counter_) {
121128
t.second = 0;
122129
}
130+
need_reset_all_vars_ = false;
123131
}
124132

125133
void RPCServer::RegisterRPC(const std::string& rpc_name,

paddle/fluid/operators/distributed/rpc_server.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ class RPCServer {
4949
bind_address_(address),
5050
exit_flag_(false),
5151
selected_port_(0),
52-
client_num_(client_num) {}
52+
client_num_(client_num),
53+
need_reset_all_vars_(false) {}
5354

5455
virtual ~RPCServer() {}
5556
virtual void StartServer() = 0;
@@ -86,6 +87,8 @@ class RPCServer {
8687
void ResetBarrierCounter();
8788
RPCServerProfiler& Profiler() { return profiler_; }
8889

90+
bool NeedResetAllVars();
91+
8992
protected:
9093
virtual void ShutDownImpl() = 0;
9194

@@ -104,6 +107,7 @@ class RPCServer {
104107
std::atomic<int> exit_flag_;
105108
int selected_port_;
106109
int client_num_;
110+
bool need_reset_all_vars_;
107111

108112
std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
109113
std::unordered_map<std::string, int> rpc_thread_num_;

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License. */
2222
#include "gflags/gflags.h"
2323

2424
#include "paddle/fluid/operators/detail/macros.h"
25+
#include "paddle/fluid/operators/math/math_function.h"
2526

2627
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
2728
#include "paddle/fluid/operators/listen_and_serv_op.h"
@@ -101,7 +102,7 @@ static int64_t GetTimestamp() {
101102

102103
void ListenAndServOp::RunSyncLoop(
103104
framework::Executor *executor, framework::ProgramDesc *program,
104-
framework::Scope *recv_scope,
105+
framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
105106
const std::vector<int> &prefetch_block_id_list,
106107
const int checkpoint_point_block_id) const {
107108
VLOG(2) << "RunSyncLoop";
@@ -128,6 +129,7 @@ void ListenAndServOp::RunSyncLoop(
128129
rpc_service_->SetCond(distributed::kRequestGet);
129130
rpc_service_->WaitBarrier(distributed::kRequestGet);
130131
rpc_service_->ResetBarrierCounter();
132+
131133
while (true) {
132134
rpc_service_->Profiler().OneStep();
133135
// Get from multiple trainers, we don't care about the order in which
@@ -165,16 +167,50 @@ void ListenAndServOp::RunSyncLoop(
165167
recv_scope);
166168
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
167169

168-
// reset received sparse vars to avoid reuse it in the next mini-batch
169-
dynamic_cast<distributed::RequestSendHandler *>(request_send_handler_.get())
170-
->ResetSparseVarRecorder();
170+
ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
171171

172172
rpc_service_->SetCond(distributed::kRequestGet);
173173
rpc_service_->WaitBarrier(distributed::kRequestGet);
174174
rpc_service_->ResetBarrierCounter();
175175
} // while(true)
176176
}
177177

178+
// Resets the variables named in sparse_vars_ (always) and dense_vars_ (only
// when reset_all is true) so they are not reused in the next iteration.
//   recv_scope: scope in which the variable names are looked up.
//   dev_ctx:    device context passed to math::set_constant for dense vars.
//   reset_all:  when true, dense variables are also filled with zero.
// Names missing from recv_scope are logged and skipped; a variable whose type
// does not match its list triggers PADDLE_THROW.
void ListenAndServOp::ResetReceivedVars(framework::Scope *recv_scope,
179+
platform::DeviceContext *dev_ctx,
180+
bool reset_all) const {
181+
// Sparse pass: drop the rows of every SelectedRows variable.
for (auto &varname : sparse_vars_) {
182+
auto var = recv_scope->FindVar(varname);
183+
if (var == nullptr) {
184+
VLOG(2) << "can not find var " << varname << " in received scope";
185+
continue;
186+
}
187+
if (var->IsType<framework::SelectedRows>()) {
188+
VLOG(3) << "reset sparse var: " << varname;
189+
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
190+
} else {
191+
PADDLE_THROW("The type of sparse var should be SelectedRows");
192+
}
193+
}
194+
// Dense pass: marked UNLIKELY; the caller in RunSyncLoop passes
// rpc_service_->NeedResetAllVars() as reset_all.
if (UNLIKELY(reset_all)) {
195+
for (auto &varname : dense_vars_) {
196+
auto var = recv_scope->FindVar(varname);
197+
if (var == nullptr) {
198+
VLOG(2) << "can not find var " << varname << " in received scope";
199+
continue;
200+
}
201+
// NOTE(review): the fill value is always a float 0; presumably dense
// received vars are float tensors -- confirm against set_constant.
if (var->IsType<framework::LoDTensor>()) {
202+
math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
203+
static_cast<float>(0));
204+
} else if (var->IsType<framework::Tensor>()) {
205+
math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(),
206+
static_cast<float>(0));
207+
} else {
208+
PADDLE_THROW("The type of dense var should be in [LoDTensor, Tensor]");
209+
}
210+
}
211+
}
212+
}
213+
178214
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
179215
framework::ProgramDesc *program,
180216
framework::Scope *recv_scope) const {
@@ -248,6 +284,25 @@ static void FillRequestCtx(
248284
h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
249285
}
250286

287+
// Classifies each name in varnames by the type of the variable found in
// scope: SelectedRows names are appended to sparse_vars_, LoDTensor/Tensor
// names to dense_vars_. A name that is not present in scope fails the
// PADDLE_ENFORCE; any other variable type triggers PADDLE_THROW.
// NOTE(review): names are appended without clearing the member vectors, so a
// second call would record duplicates -- confirm this runs once per op.
void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames,
288+
const framework::Scope &scope) const {
289+
for (const auto &varname : varnames) {
290+
auto var = scope.FindVar(varname);
291+
PADDLE_ENFORCE(var != nullptr,
292+
"Received var should be initialized in the received scope.");
293+
if (var->IsType<framework::SelectedRows>()) {
294+
sparse_vars_.push_back(varname);
295+
} else if (var->IsType<framework::LoDTensor>() ||
296+
var->IsType<framework::Tensor>()) {
297+
dense_vars_.push_back(varname);
298+
} else {
299+
PADDLE_THROW(
300+
"The type of received var should be in [SelectedRows, LoDTensor, "
301+
"Tensor].");
302+
}
303+
}
304+
}
305+
251306
void ListenAndServOp::RunImpl(const framework::Scope &scope,
252307
const platform::Place &dev_place) const {
253308
// Mark this as PS that it should decide profiling by listening from trainer.
@@ -258,6 +313,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
258313

259314
bool sync_mode = Attr<bool>("sync_mode");
260315
auto fan_in = Attr<int>("Fanin");
316+
auto inputs = Inputs("X");
261317

262318
PADDLE_ENFORCE(!rpc_service_);
263319
std::string endpoint = Attr<std::string>("endpoint");
@@ -348,11 +404,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
348404
signal(SIGINT, SignalHandler::StopAndExit);
349405
signal(SIGTERM, SignalHandler::StopAndExit);
350406

407+
// Cache the type of the received vars as `sparse_vars_` and `dense_vars_`
408+
// so that we can reset them at the end of each iteration.
409+
// NOTE: only used in sync update
410+
CacheVarsType(inputs, recv_scope);
411+
351412
// Write to a file of server selected port for python use.
352413
SavePort();
353414
if (sync_mode) {
354-
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list,
355-
checkpoint_block_id);
415+
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
416+
prefetch_block_id_list, checkpoint_block_id);
356417
} else {
357418
RunAsyncLoop(&executor, program, &recv_scope);
358419
}

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ limitations under the License. */
2626
#include "paddle/fluid/framework/threadpool.h"
2727
#include "paddle/fluid/operators/distributed/request_handler.h"
2828
#include "paddle/fluid/operators/distributed/rpc_server.h"
29+
#include "paddle/fluid/platform/device_context.h"
2930

3031
namespace paddle {
3132
namespace operators {
@@ -48,6 +49,7 @@ class ListenAndServOp : public framework::OperatorBase {
4849
void RunSyncLoop(framework::Executor* executor,
4950
framework::ProgramDesc* program,
5051
framework::Scope* recv_scope,
52+
platform::DeviceContext* dev_ctx,
5153
const std::vector<int>& prefetch_block_id_list,
5254
const int checkpoint_point_block_id) const;
5355

@@ -64,6 +66,13 @@ class ListenAndServOp : public framework::OperatorBase {
6466
void RunImpl(const framework::Scope& scope,
6567
const platform::Place& dev_place) const override;
6668

69+
void ResetReceivedVars(framework::Scope* recv_scope,
70+
platform::DeviceContext* dev_ctx,
71+
bool reset_all = false) const;
72+
73+
void CacheVarsType(const std::vector<std::string>& varnames,
74+
const framework::Scope& scope) const;
75+
6776
protected:
6877
mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
6978
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
@@ -74,6 +83,8 @@ class ListenAndServOp : public framework::OperatorBase {
7483
request_checkpoint_handler_;
7584

7685
mutable std::shared_ptr<std::thread> server_thread_;
86+
mutable std::vector<std::string> sparse_vars_;
87+
mutable std::vector<std::string> dense_vars_;
7788
};
7889

7990
class SignalHandler {

0 commit comments

Comments
 (0)