Commit 86e09b3

fix async update error on pserver
1 parent 6d6996a commit 86e09b3

File tree

3 files changed: +13 -14 lines changed


paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 7 additions & 2 deletions
@@ -163,7 +163,8 @@ void ListenAndServOp::RunSyncLoop(
 }
 
 void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
-                                   framework::ProgramDesc *program) const {
+                                   framework::ProgramDesc *program,
+                                   framework::Scope *recv_scope) const {
   VLOG(3) << "RunAsyncLoop in";
   // grad name to block id
   std::unordered_map<std::string, int32_t> grad_to_block_id;
@@ -191,6 +192,10 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
     block_list.push_back(blkid);
   }
   auto optimize_prepared = executor->Prepare(*program, block_list);
+  // execute global block if needed
+  if (block_list[0] == 1 && id_to_grad.count(1) == 0) {
+    executor->RunPreparedContext(optimize_prepared[0].get(), recv_scope);
+  }
   std::unordered_map<std::string,
                      std::shared_ptr<framework::ExecutorPrepareContext>>
       grad_to_prepared_ctx;
@@ -319,7 +324,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   if (sync_mode) {
     RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list);
   } else {
-    RunAsyncLoop(&executor, program);
+    RunAsyncLoop(&executor, program, &recv_scope);
   }
 }
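The heart of the fix is the new guard: RunAsyncLoop now receives recv_scope so that, when the optimize block list starts at block 1 but no gradient maps to that block, the prepared context for it is run once up front; otherwise the per-gradient async loop would never execute it (the commit's comment calls this "execute global block if needed"). A minimal sketch of the condition, using plain Python containers as hypothetical stand-ins for the pserver's bookkeeping rather than the real Paddle API:

# block_list holds the optimize block ids parsed from grad_to_block_id;
# id_to_grad maps a block id back to the gradient it optimizes.
# Both are hypothetical stand-ins, not Paddle's actual structures.
block_list = [1, 2, 3]
id_to_grad = {2: "w@GRAD", 3: "b@GRAD"}

# If block 1 heads the list but no gradient maps to it, the per-gradient
# loop would never run it, so it is executed once against recv_scope.
if block_list[0] == 1 and 1 not in id_to_grad:
    print("run prepared context of block 1 on recv_scope")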

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 2 additions & 1 deletion
@@ -50,7 +50,8 @@ class ListenAndServOp : public framework::OperatorBase {
                    const std::vector<int>& prefetch_block_id_list) const;
 
   void RunAsyncLoop(framework::Executor* executor,
-                    framework::ProgramDesc* program) const;
+                    framework::ProgramDesc* program,
+                    framework::Scope* recv_scope) const;
 
   void SavePort() const;

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 4 additions & 11 deletions
@@ -1299,16 +1299,6 @@ def _create_ufind(self, optimize_ops):
                     ufind.union(op1, op2)
         return ufind
 
-    def _is_opt_role_op(self, op):
-        # NOTE: depend on oprole to find out whether this op is for
-        # optimize
-        op_maker = core.op_proto_and_checker_maker
-        optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
-        if op_maker.kOpRoleAttrName() in op.attrs and \
-                int(op.attrs[op_maker.kOpRoleAttrName()]) == int(optimize_role):
-            return True
-        return False
-
     def _is_optimizer_op(self, op):
         if "Param" in op.input_names and \
                 "LearningRate" in op.input_names:
@@ -1399,7 +1389,10 @@ def _get_optimize_pass(self):
         params_grads = []
         origin_var_dict = self.origin_program.global_block().vars
         for op in block.ops:
-            if self._is_opt_role_op(op):
+            # NOTE(Yancey1989): we can not use op role to distinguish an optimizer op
+            # or not, because all ops in optimizer sub-graph would
+            # sign the optimizer op role
+            if self._is_optimizer_op(op):
                 opt_ops.append(op)
                 # HACK(wuyi): if we find grad vars from input of optimize
                 # ops, we may get the output of clip op. Use syntax "@GRAD"
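The transpiler side of the fix swaps role-based detection for input-based detection: as the new NOTE explains, every op in the optimizer sub-graph (clip and friends) is signed with the Optimize op role, so _is_opt_role_op over-matched, while _is_optimizer_op keys on the one property unique to genuine optimizer ops, taking both a Param and a LearningRate input. A minimal sketch of the contrast under assumed attribute names and values (FakeOp, OP_ROLE_ATTR, and OPTIMIZE_ROLE are illustrative stand-ins for what core.op_proto_and_checker_maker provides):

class FakeOp:
    # Hypothetical stand-in for a framework op.
    def __init__(self, input_names, attrs):
        self.input_names = input_names
        self.attrs = attrs

OP_ROLE_ATTR = "op_role"  # stands in for op_maker.kOpRoleAttrName()
OPTIMIZE_ROLE = 2         # stands in for OpRole.Optimize

def is_opt_role_op(op):
    # The removed check: true for *every* op in the optimizer sub-graph.
    return int(op.attrs.get(OP_ROLE_ATTR, -1)) == int(OPTIMIZE_ROLE)

def is_optimizer_op(op):
    # The kept check: only genuine optimizer ops (sgd, adam, ...) take
    # both a parameter and a learning rate as inputs.
    return "Param" in op.input_names and "LearningRate" in op.input_names

sgd = FakeOp(["Param", "Grad", "LearningRate"], {OP_ROLE_ATTR: OPTIMIZE_ROLE})
clip = FakeOp(["X"], {OP_ROLE_ATTR: OPTIMIZE_ROLE})  # also signed Optimize

assert is_opt_role_op(sgd) and is_opt_role_op(clip)        # role over-matches
assert is_optimizer_op(sgd) and not is_optimizer_op(clip)  # inputs are precise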
