Skip to content

Commit f25abba

Browse files
authored
Merge pull request #1 from Yancey1989/panxin_polish_sparse
polish sparse update code
2 parents e0895e4 + 1239fce commit f25abba

File tree

4 files changed

+23
-0
lines changed

4 files changed

+23
-0
lines changed

paddle/fluid/operators/detail/request_handler_impl.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ bool RequestSendHandler::Handle(const std::string& varname,
6363
PADDLE_THROW("sync: Can not find server side var");
6464
return false;
6565
}
66+
if (invar->IsType<framework::SelectedRows>()) {
67+
rpc_server_->RecordSparseVar(invar);
68+
}
6669
}
6770

6871
return true;

paddle/fluid/operators/detail/rpc_server.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,19 @@ void RPCServer::ResetBarrierCounter() {
7373
t.second = 0;
7474
}
7575
}
76+
void RPCServer::RecordSparseVar(framework::Variable* sparse_var) {
77+
std::unique_lock<std::mutex> lock(mutex_sparse_var_recorder_);
78+
sparse_vars_.push_back(sparse_var);
79+
}
80+
81+
void RPCServer::ResetSparseVarsRecorder() {
82+
VLOG(3) << "RPCServer reset sparse vars recorder.";
83+
std::unique_lock<std::mutex> lock(mutex_sparse_var_recorder_);
84+
for (auto* var : sparse_vars_) {
85+
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
86+
}
87+
sparse_vars_.clear();
88+
}
7689

7790
void RPCServer::RegisterRPC(const std::string& rpc_name,
7891
RequestHandler* handler, int thread_num) {

paddle/fluid/operators/detail/rpc_server.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ class RPCServer {
6060
void SetCond(const std::string& rpc_name);
6161
void WaitCond(const std::string& rpc_name);
6262
void IncreaseBatchBarrier(const std::string rpc_name);
63+
6364
void ResetBarrierCounter();
65+
void RecordSparseVar(framework::Variable* sparse_var);
66+
void ResetSparseVarsRecorder();
6467

6568
protected:
6669
virtual void ShutDownImpl() = 0;
@@ -74,6 +77,9 @@ class RPCServer {
7477
std::atomic<int> cur_cond_;
7578
std::condition_variable rpc_cond_;
7679

80+
std::vector<framework::Variable*> sparse_vars_;
81+
std::mutex mutex_sparse_var_recorder_;
82+
7783
protected:
7884
std::string bind_address_;
7985
std::atomic<int> exit_flag_;

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
146146
rpc_service_->SetCond(detail::kRequestGet);
147147
rpc_service_->WaitBarrier(detail::kRequestGet);
148148
rpc_service_->ResetBarrierCounter();
149+
rpc_service_->ResetSparseVarsRecorder();
149150
} // while(true)
150151
}
151152

0 commit comments

Comments
 (0)