
Commit 6675ba2

Authored by Yancey

Merge pull request #11736 from Yancey1989/fix_async_update

[cherry-pick] Fix async update

2 parents be0db33 + 4bb612e, commit 6675ba2

3 files changed: 13 additions, 14 deletions

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 7 additions & 2 deletions
@@ -163,7 +163,8 @@ void ListenAndServOp::RunSyncLoop(
 }
 
 void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
-                                   framework::ProgramDesc *program) const {
+                                   framework::ProgramDesc *program,
+                                   framework::Scope *recv_scope) const {
   // grad name to block id
   std::unordered_map<std::string, int32_t> grad_to_block_id;
   std::unordered_map<int32_t, std::string> id_to_grad;
@@ -190,6 +191,10 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
     block_list.push_back(blkid);
   }
   auto optimize_prepared = executor->Prepare(*program, block_list);
+  // execute global block if needed
+  if (block_list[0] == 1 && id_to_grad.count(1) == 0) {
+    executor->RunPreparedContext(optimize_prepared[0].get(), recv_scope);
+  }
   std::unordered_map<std::string,
                      std::shared_ptr<framework::ExecutorPrepareContext>>
       grad_to_prepared_ctx;
@@ -317,7 +322,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   if (sync_mode) {
     RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list);
   } else {
-    RunAsyncLoop(&executor, program);
+    RunAsyncLoop(&executor, program, &recv_scope);
   }
 }
 
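For readers who want the intent of the RunAsyncLoop change at a glance, here is a small runnable Python-flavored sketch of the new control flow. StubExecutor, its methods, and the block ids are illustrative stand-ins for the C++ framework::Executor / framework::Scope machinery, not a real Paddle Python API; only the "run the global block once if no gradient is bound to it" logic mirrors the diff above.

# Python sketch (illustrative stand-ins only) of the new RunAsyncLoop behavior.
class StubExecutor(object):
    def prepare(self, program, block_ids):
        # stands in for executor->Prepare(*program, block_list)
        return ["prepared-block-%d" % blkid for blkid in block_ids]

    def run_prepared(self, prepared_ctx, scope):
        # stands in for executor->RunPreparedContext(ctx, recv_scope)
        print("running %s on scope %r" % (prepared_ctx, scope))


def run_async_loop(executor, program, recv_scope, block_list, id_to_grad):
    optimize_prepared = executor.prepare(program, block_list)

    # New in this commit: if the first optimize block is block 1 and no
    # gradient maps to it, it is a global block that is executed once on
    # the received scope before the async serving loop starts.
    if block_list[0] == 1 and 1 not in id_to_grad:
        executor.run_prepared(optimize_prepared[0], recv_scope)

    # ... the per-gradient dispatch loop of RunAsyncLoop continues here ...


if __name__ == "__main__":
    # Block 1 holds ops not tied to any gradient; blocks 2-3 are the
    # per-gradient optimize blocks.
    run_async_loop(StubExecutor(), program=None, recv_scope="recv_scope",
                   block_list=[1, 2, 3], id_to_grad={2: "w@GRAD", 3: "b@GRAD"})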

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 2 additions & 1 deletion
@@ -50,7 +50,8 @@ class ListenAndServOp : public framework::OperatorBase {
                    const std::vector<int>& prefetch_block_id_list) const;
 
   void RunAsyncLoop(framework::Executor* executor,
-                    framework::ProgramDesc* program) const;
+                    framework::ProgramDesc* program,
+                    framework::Scope* recv_scope) const;
 
   void SavePort() const;
 

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 4 additions & 11 deletions
@@ -1299,16 +1299,6 @@ def _create_ufind(self, optimize_ops):
                     ufind.union(op1, op2)
         return ufind
 
-    def _is_opt_role_op(self, op):
-        # NOTE: depend on oprole to find out whether this op is for
-        # optimize
-        op_maker = core.op_proto_and_checker_maker
-        optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
-        if op_maker.kOpRoleAttrName() in op.attrs and \
-            int(op.attrs[op_maker.kOpRoleAttrName()]) == int(optimize_role):
-            return True
-        return False
-
     def _is_optimizer_op(self, op):
         if "Param" in op.input_names and \
             "LearningRate" in op.input_names:
@@ -1399,7 +1389,10 @@ def _get_optimize_pass(self):
         params_grads = []
         origin_var_dict = self.origin_program.global_block().vars
         for op in block.ops:
-            if self._is_opt_role_op(op):
+            # NOTE(Yancey1989): we can not use op role to distinguish an optimizer op
+            # or not, because all ops in optimizer sub-graph would
+            # sign the optimizer op role
+            if self._is_optimizer_op(op):
                 opt_ops.append(op)
                 # HACK(wuyi): if we find grad vars from input of optimize
                 # ops, we may get the output of clip op. Use syntax "@GRAD"
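For context on the transpiler change above: the deleted _is_opt_role_op keyed on the op-role attribute, which (per the NOTE in the diff) is set on every op inside the optimizer sub-graph, so it over-matches helper ops; _is_optimizer_op instead keys on the op's input names. The snippet below is a minimal runnable sketch contrasting the two checks. FakeOp, OP_ROLE_ATTR, and OPTIMIZE_ROLE are illustrative stand-ins, not the actual Paddle core API; only the selection logic mirrors the commit.

# Minimal sketch (stand-in types, not Paddle's core API) contrasting the checks.
OPTIMIZE_ROLE = 2          # assumed numeric value standing in for OpRole.Optimize
OP_ROLE_ATTR = "op_role"   # assumed attribute name


class FakeOp(object):
    def __init__(self, op_type, input_names, attrs):
        self.type = op_type
        self.input_names = input_names   # e.g. ["Param", "Grad", "LearningRate"]
        self.attrs = attrs               # e.g. {"op_role": OPTIMIZE_ROLE}


def is_opt_role_op(op):
    # Old check (removed by this commit): every op in the optimizer
    # sub-graph carries the Optimize role, so helper ops such as scale
    # or clip are matched as well.
    return OP_ROLE_ATTR in op.attrs and \
        int(op.attrs[OP_ROLE_ATTR]) == OPTIMIZE_ROLE


def is_optimizer_op(op):
    # New check (used by this commit): only ops that actually update a
    # parameter declare both "Param" and "LearningRate" inputs.
    return "Param" in op.input_names and \
        "LearningRate" in op.input_names


if __name__ == "__main__":
    sgd = FakeOp("sgd", ["Param", "Grad", "LearningRate"],
                 {OP_ROLE_ATTR: OPTIMIZE_ROLE})
    scale = FakeOp("scale", ["X"], {OP_ROLE_ATTR: OPTIMIZE_ROLE})

    # Role-based check matches both ops; input-name check matches only sgd.
    print(is_opt_role_op(sgd), is_opt_role_op(scale))    # True True
    print(is_optimizer_op(sgd), is_optimizer_op(scale))  # True False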
