Skip to content

Commit 6edfae4

Browse files
committed
reset received vars on pserver
1 parent f76f42c commit 6edfae4

File tree

6 files changed

+58
-26
lines changed

6 files changed

+58
-26
lines changed

paddle/fluid/operators/distributed/request_handler_impl.cc

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,24 +67,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
6767
LOG(FATAL) << "sync: Can not find server side var: " << varname;
6868
return false;
6969
}
70-
71-
if (invar->IsType<framework::SelectedRows>()) {
72-
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
73-
sparse_vars_.push_back(invar);
74-
}
7570
}
7671
}
7772
return true;
7873
}
7974

80-
void RequestSendHandler::ResetSparseVarRecorder() {
81-
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
82-
for (auto* var : sparse_vars_) {
83-
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
84-
}
85-
sparse_vars_.clear();
86-
}
87-
8875
bool RequestGetHandler::Handle(const std::string& varname,
8976
framework::Scope* scope,
9077
framework::Variable* invar,

paddle/fluid/operators/distributed/request_handler_impl.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,6 @@ class RequestSendHandler final : public RequestHandler {
4141
bool Handle(const std::string& varname, framework::Scope* scope,
4242
framework::Variable* var, framework::Variable** outvar,
4343
const std::string& out_var_name = "") override;
44-
void ResetSparseVarRecorder();
45-
46-
private:
47-
std::mutex mutex_sparse_vars_;
48-
std::vector<framework::Variable*> sparse_vars_;
4944
};
5045

5146
class RequestGetHandler final : public RequestHandler {

paddle/fluid/operators/distributed/rpc_server.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ void RPCServer::Complete() {
101101
{
102102
std::unique_lock<std::mutex> lock(mutex_);
103103
client_num_--;
104+
need_reset_all_vars_ = true;
105+
104106
VLOG(4) << "decrease client_num to: " << client_num_;
105107
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
106108
barrier_counter_[kRequestGet]--;
@@ -109,6 +111,11 @@ void RPCServer::Complete() {
109111
barrier_cond_.notify_all();
110112
}
111113

114+
bool RPCServer::NeedResetAllVars() {
115+
std::unique_lock<std::mutex> lock(mutex_);
116+
return need_reset_all_vars_;
117+
}
118+
112119
int RPCServer::GetClientNum() {
113120
std::unique_lock<std::mutex> lock(mutex_);
114121
return client_num_;
@@ -120,6 +127,7 @@ void RPCServer::ResetBarrierCounter() {
120127
for (auto& t : barrier_counter_) {
121128
t.second = 0;
122129
}
130+
need_reset_all_vars_ = false;
123131
}
124132

125133
void RPCServer::RegisterRPC(const std::string& rpc_name,

paddle/fluid/operators/distributed/rpc_server.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ class RPCServer {
4949
bind_address_(address),
5050
exit_flag_(false),
5151
selected_port_(0),
52-
client_num_(client_num) {}
52+
client_num_(client_num),
53+
need_reset_all_vars_(false) {}
5354

5455
virtual ~RPCServer() {}
5556
virtual void StartServer() = 0;
@@ -86,6 +87,8 @@ class RPCServer {
8687
void ResetBarrierCounter();
8788
RPCServerProfiler& Profiler() { return profiler_; }
8889

90+
bool NeedResetAllVars();
91+
8992
protected:
9093
virtual void ShutDownImpl() = 0;
9194

@@ -104,6 +107,7 @@ class RPCServer {
104107
std::atomic<int> exit_flag_;
105108
int selected_port_;
106109
int client_num_;
110+
bool need_reset_all_vars_;
107111

108112
std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
109113
std::unordered_map<std::string, int> rpc_thread_num_;

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License. */
2222
#include "gflags/gflags.h"
2323

2424
#include "paddle/fluid/operators/detail/macros.h"
25+
#include "paddle/fluid/operators/math/math_function.h"
2526

2627
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
2728
#include "paddle/fluid/operators/listen_and_serv_op.h"
@@ -101,9 +102,10 @@ static int64_t GetTimestamp() {
101102

102103
void ListenAndServOp::RunSyncLoop(
103104
framework::Executor *executor, framework::ProgramDesc *program,
104-
framework::Scope *recv_scope,
105+
framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
105106
const std::vector<int> &prefetch_block_id_list,
106-
const int checkpoint_point_block_id) const {
107+
const int checkpoint_point_block_id,
108+
const std::vector<std::string> &recv_varnames) const {
107109
VLOG(2) << "RunSyncLoop";
108110
size_t num_blocks = program->Size();
109111
auto optimize_blocks =
@@ -166,15 +168,42 @@ void ListenAndServOp::RunSyncLoop(
166168
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
167169

168170
// reset received sparse vars to avoid reuse it in the next mini-batch
169-
dynamic_cast<distributed::RequestSendHandler *>(request_send_handler_.get())
170-
->ResetSparseVarRecorder();
171+
ResetReceivedVars(recv_varnames, recv_scope, dev_ctx,
172+
!rpc_service_->NeedResetAllVars());
171173

172174
rpc_service_->SetCond(distributed::kRequestGet);
173175
rpc_service_->WaitBarrier(distributed::kRequestGet);
174176
rpc_service_->ResetBarrierCounter();
175177
} // while(true)
176178
}
177179

180+
void ListenAndServOp::ResetReceivedVars(
181+
const std::vector<std::string> &recv_varnames, framework::Scope *recv_scope,
182+
platform::DeviceContext *dev_ctx, bool only_sparse_vars) const {
183+
for (auto &varname : recv_varnames) {
184+
auto var = recv_scope->FindVar(varname);
185+
if (var == nullptr) {
186+
VLOG(2) << "can not find var " << varname << " in received scope";
187+
continue;
188+
}
189+
if (var->IsType<framework::SelectedRows>()) {
190+
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
191+
}
192+
if (!only_sparse_vars) {
193+
if (var->IsType<framework::LoDTensor>()) {
194+
math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
195+
static_cast<float>(0));
196+
} else if (var->IsType<framework::Tensor>()) {
197+
math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(),
198+
static_cast<float>(0));
199+
} else {
200+
PADDLE_THROW(
201+
"received var should be in [SelectedRows, LoDTensor, Tensor]");
202+
}
203+
}
204+
}
205+
}
206+
178207
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
179208
framework::ProgramDesc *program,
180209
framework::Scope *recv_scope) const {
@@ -258,6 +287,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
258287

259288
bool sync_mode = Attr<bool>("sync_mode");
260289
auto fan_in = Attr<int>("Fanin");
290+
auto inputs = Inputs("X");
261291

262292
PADDLE_ENFORCE(!rpc_service_);
263293
std::string endpoint = Attr<std::string>("endpoint");
@@ -351,8 +381,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
351381
// Write to a file of server selected port for python use.
352382
SavePort();
353383
if (sync_mode) {
354-
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list,
355-
checkpoint_block_id);
384+
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
385+
prefetch_block_id_list, checkpoint_block_id, inputs);
356386
} else {
357387
RunAsyncLoop(&executor, program, &recv_scope);
358388
}

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ limitations under the License. */
2626
#include "paddle/fluid/framework/threadpool.h"
2727
#include "paddle/fluid/operators/distributed/request_handler.h"
2828
#include "paddle/fluid/operators/distributed/rpc_server.h"
29+
#include "paddle/fluid/platform/device_context.h"
2930

3031
namespace paddle {
3132
namespace operators {
@@ -48,8 +49,10 @@ class ListenAndServOp : public framework::OperatorBase {
4849
void RunSyncLoop(framework::Executor* executor,
4950
framework::ProgramDesc* program,
5051
framework::Scope* recv_scope,
52+
platform::DeviceContext* dev_ctx,
5153
const std::vector<int>& prefetch_block_id_list,
52-
const int checkpoint_point_block_id) const;
54+
const int checkpoint_point_block_id,
55+
const std::vector<std::string>& recv_varnames) const;
5356

5457
void RunAsyncLoop(framework::Executor* executor,
5558
framework::ProgramDesc* program,
@@ -64,6 +67,11 @@ class ListenAndServOp : public framework::OperatorBase {
6467
void RunImpl(const framework::Scope& scope,
6568
const platform::Place& dev_place) const override;
6669

70+
void ResetReceivedVars(const std::vector<std::string>& recv_varnames,
71+
framework::Scope* recv_scope,
72+
platform::DeviceContext* dev_ctx,
73+
bool only_sparse_vars = true) const;
74+
6775
protected:
6876
mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
6977
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;

0 commit comments

Comments
 (0)