Skip to content

Commit 580f55f

Browse files
committed
update by comment
1 parent 6edfae4 commit 580f55f

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,8 @@ void ListenAndServOp::RunSyncLoop(
167167
recv_scope);
168168
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
169169

170-
// reset received sparse vars to avoid reuse it in the next mini-batch
171170
ResetReceivedVars(recv_varnames, recv_scope, dev_ctx,
172-
!rpc_service_->NeedResetAllVars());
171+
rpc_service_->NeedResetAllVars());
173172

174173
rpc_service_->SetCond(distributed::kRequestGet);
175174
rpc_service_->WaitBarrier(distributed::kRequestGet);
@@ -179,17 +178,19 @@ void ListenAndServOp::RunSyncLoop(
179178

180179
void ListenAndServOp::ResetReceivedVars(
181180
const std::vector<std::string> &recv_varnames, framework::Scope *recv_scope,
182-
platform::DeviceContext *dev_ctx, bool only_sparse_vars) const {
181+
platform::DeviceContext *dev_ctx, bool reset_all) const {
183182
for (auto &varname : recv_varnames) {
184183
auto var = recv_scope->FindVar(varname);
185184
if (var == nullptr) {
186185
VLOG(2) << "can not find var " << varname << " in received scope";
187186
continue;
188187
}
189188
if (var->IsType<framework::SelectedRows>()) {
189+
VLOG(3) << "reset sparse var: " << varname;
190190
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
191191
}
192-
if (!only_sparse_vars) {
192+
if (UNLIKELY(reset_all)) {
193+
VLOG(3) << "reset dense var: " << varname;
193194
if (var->IsType<framework::LoDTensor>()) {
194195
math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
195196
static_cast<float>(0));

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class ListenAndServOp : public framework::OperatorBase {
7070
void ResetReceivedVars(const std::vector<std::string>& recv_varnames,
7171
framework::Scope* recv_scope,
7272
platform::DeviceContext* dev_ctx,
73-
bool only_sparse_vars = true) const;
73+
bool reset_all = false) const;
7474

7575
protected:
7676
mutable std::shared_ptr<distributed::RPCServer> rpc_service_;

0 commit comments

Comments
 (0)