Skip to content

Commit 7bcc980

Browse files
authored
Merge pull request #11321 from Yancey1989/polish_sparse_update
polish sparse update logic
2 parents eced973 + 5696494 commit 7bcc980

File tree

5 files changed

+18
-21
lines changed

5 files changed

+18
-21
lines changed

paddle/fluid/operators/detail/request_handler_impl.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,21 @@ bool RequestSendHandler::Handle(const std::string& varname,
6464
return false;
6565
}
6666
if (invar->IsType<framework::SelectedRows>()) {
67-
rpc_server_->RecordSparseVar(invar);
67+
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
68+
sparse_vars_.push_back(invar);
6869
}
6970
}
70-
7171
return true;
7272
}
7373

74+
void RequestSendHandler::ResetSparseVarRecorder() {
75+
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
76+
for (auto* var : sparse_vars_) {
77+
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
78+
}
79+
sparse_vars_.clear();
80+
}
81+
7482
bool RequestGetHandler::Handle(const std::string& varname,
7583
framework::Scope* scope,
7684
framework::Variable* invar,

paddle/fluid/operators/detail/request_handler_impl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ class RequestSendHandler final : public RequestHandler {
4141
virtual ~RequestSendHandler() {}
4242
bool Handle(const std::string& varname, framework::Scope* scope,
4343
framework::Variable* var, framework::Variable** outvar) override;
44+
void ResetSparseVarRecorder();
45+
46+
private:
47+
std::mutex mutex_sparse_vars_;
48+
std::vector<framework::Variable*> sparse_vars_;
4449
};
4550

4651
class RequestGetHandler final : public RequestHandler {

paddle/fluid/operators/detail/rpc_server.cc

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -73,19 +73,6 @@ void RPCServer::ResetBarrierCounter() {
7373
t.second = 0;
7474
}
7575
}
76-
void RPCServer::RecordSparseVar(framework::Variable* sparse_var) {
77-
std::unique_lock<std::mutex> lock(mutex_sparse_var_recorder_);
78-
sparse_vars_.push_back(sparse_var);
79-
}
80-
81-
void RPCServer::ResetSparseVarsRecorder() {
82-
VLOG(3) << "RPCServer reset sparse vars recorder.";
83-
std::unique_lock<std::mutex> lock(mutex_sparse_var_recorder_);
84-
for (auto* var : sparse_vars_) {
85-
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
86-
}
87-
sparse_vars_.clear();
88-
}
8976

9077
void RPCServer::RegisterRPC(const std::string& rpc_name,
9178
RequestHandler* handler, int thread_num) {

paddle/fluid/operators/detail/rpc_server.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@ class RPCServer {
6262
void IncreaseBatchBarrier(const std::string rpc_name);
6363

6464
void ResetBarrierCounter();
65-
void RecordSparseVar(framework::Variable* sparse_var);
66-
void ResetSparseVarsRecorder();
6765

6866
protected:
6967
virtual void ShutDownImpl() = 0;
@@ -77,9 +75,6 @@ class RPCServer {
7775
std::atomic<int> cur_cond_;
7876
std::condition_variable rpc_cond_;
7977

80-
std::vector<framework::Variable*> sparse_vars_;
81-
std::mutex mutex_sparse_var_recorder_;
82-
8378
protected:
8479
std::string bind_address_;
8580
std::atomic<int> exit_flag_;

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,9 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
146146
rpc_service_->SetCond(detail::kRequestGet);
147147
rpc_service_->WaitBarrier(detail::kRequestGet);
148148
rpc_service_->ResetBarrierCounter();
149-
rpc_service_->ResetSparseVarsRecorder();
149+
// reset received sparse vars to avoid reuse it in the next mini-batch
150+
dynamic_cast<detail::RequestSendHandler *>(request_send_handler_.get())
151+
->ResetSparseVarRecorder();
150152
} // while(true)
151153
}
152154

0 commit comments

Comments
 (0)