Skip to content

Commit 259e63d

Browse files
authored
Merge pull request #11248 from panyx0718/dist
Fix sparse vars usage for dist train
2 parents 2d7c836 + f25abba commit 259e63d

File tree

5 files changed

+21
-22
lines changed

5 files changed

+21
-22
lines changed

paddle/fluid/operators/detail/request_handler.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ class RequestHandler {
8080
}
8181
framework::ProgramDesc* program() { return program_; }
8282
framework::Executor* executor() { return executor_; }
83-
std::vector<framework::Variable*>& sparse_vars() { return sparse_vars_; }
8483

8584
// This function processes user's rpc request.
8685
// The implementation is in request_handler_impl.
@@ -113,13 +112,7 @@ class RequestHandler {
113112
std::unordered_map<std::string,
114113
std::shared_ptr<framework::ExecutorPrepareContext>>*
115114
grad_to_prepared_ctx_;
116-
117-
// Record received sparse variables, so that
118-
// we can reset them after executing the optimize program
119-
std::vector<framework::Variable*> sparse_vars_;
120115
RPCServer* rpc_server_;
121-
122-
std::mutex sparse_var_mutex_;
123116
};
124117

125118
} // namespace detail

paddle/fluid/operators/detail/request_handler_impl.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
6363
PADDLE_THROW("sync: Can not find server side var");
6464
return false;
6565
}
66-
6766
if (invar->IsType<framework::SelectedRows>()) {
68-
std::unique_lock<std::mutex> lock(sparse_var_mutex_);
69-
sparse_vars_.push_back(invar);
67+
rpc_server_->RecordSparseVar(invar);
7068
}
7169
}
7270

paddle/fluid/operators/detail/rpc_server.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,19 @@ void RPCServer::ResetBarrierCounter() {
7373
t.second = 0;
7474
}
7575
}
76+
// Remember a received sparse (SelectedRows) variable so that its rows can be
// cleared once the optimize program has executed for this mini-batch.
// Thread-safe: sparse_vars_ is guarded by mutex_sparse_var_recorder_.
void RPCServer::RecordSparseVar(framework::Variable* sparse_var) {
  std::lock_guard<std::mutex> guard(mutex_sparse_var_recorder_);
  sparse_vars_.push_back(sparse_var);
}
80+
81+
void RPCServer::ResetSparseVarsRecorder() {
82+
VLOG(3) << "RPCServer reset sparse vars recorder.";
83+
std::unique_lock<std::mutex> lock(mutex_sparse_var_recorder_);
84+
for (auto* var : sparse_vars_) {
85+
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
86+
}
87+
sparse_vars_.clear();
88+
}
7689

7790
void RPCServer::RegisterRPC(const std::string& rpc_name,
7891
RequestHandler* handler, int thread_num) {

paddle/fluid/operators/detail/rpc_server.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ class RPCServer {
6060
void SetCond(const std::string& rpc_name);
6161
void WaitCond(const std::string& rpc_name);
6262
void IncreaseBatchBarrier(const std::string rpc_name);
63+
6364
void ResetBarrierCounter();
65+
void RecordSparseVar(framework::Variable* sparse_var);
66+
void ResetSparseVarsRecorder();
6467

6568
protected:
6669
virtual void ShutDownImpl() = 0;
@@ -74,6 +77,9 @@ class RPCServer {
7477
std::atomic<int> cur_cond_;
7578
std::condition_variable rpc_cond_;
7679

80+
std::vector<framework::Variable*> sparse_vars_;
81+
std::mutex mutex_sparse_var_recorder_;
82+
7783
protected:
7884
std::string bind_address_;
7985
std::atomic<int> exit_flag_;

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,6 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
108108
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
109109

110110
rpc_service_->ResetBarrierCounter();
111-
// Record received sparse variables, so that
112-
// we could reset those after execute optimize program
113-
std::vector<framework::Variable *> sparse_vars;
114111
while (true) {
115112
// Get from multiple trainers, we don't care about the order in which
116113
// the gradients arrives, just add suffix 0~n and merge the gradient.
@@ -146,18 +143,10 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
146143
recv_scope);
147144
VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
148145

149-
// Reset the received sparse variables, the sum operator would not
150-
// sum the input sparse variables whose rows are empty at the next
151-
// mini-batch.
152-
// TODO(Yancey1989): move the reset action into an operator, we couldn't
153-
// have any hide logic in the operator.
154-
for (framework::Variable *var : sparse_vars) {
155-
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
156-
}
157-
158146
rpc_service_->SetCond(detail::kRequestGet);
159147
rpc_service_->WaitBarrier(detail::kRequestGet);
160148
rpc_service_->ResetBarrierCounter();
149+
rpc_service_->ResetSparseVarsRecorder();
161150
} // while(true)
162151
}
163152

0 commit comments

Comments
 (0)