Skip to content

Commit 4cd44c0

Browse files
committed
fix
test=develop
1 parent 38cf553 commit 4cd44c0

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
5050
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
5151
)
52+
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
5253
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
5354
DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist
5455
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
@@ -1717,8 +1718,10 @@ def _get_lr_ops(self):
17171718
lr_ops = []
17181719
block = self.origin_program.global_block()
17191720
for op in block.ops:
1720-
if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) | int(
1721-
LR_SCHED_OP_ROLE_ATTR_VALUE) > 0:
1721+
role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME))
1722+
if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \
1723+
role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \
1724+
int(OPT_OP_ROLE_ATTR_VALUE):
17221725
lr_ops.append(op)
17231726
log("append lr op: ", op.type)
17241727
return lr_ops

0 commit comments

Comments
 (0)