Skip to content

Commit e9737d6

Browse files
committed
add a private function to find adam opt pass
1 parent da960ad commit e9737d6

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

python/paddle/fluid/distribute_transpiler.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -401,11 +401,8 @@ def get_pserver_program(self, endpoint):
401401
# HACK: optimization global ops only used to scale beta1 and beta2
402402
# replace it with dependency engine.
403403
for op in self.optimize_ops:
404-
if op.type == "scale":
405-
for in_name in op.input_arg_names:
406-
if in_name.startswith("beta1_pow_acc") or \
407-
in_name.startswith("beta2_pow_acc"):
408-
global_ops.append(op)
404+
if self._is_adam_connected_op(op):
405+
global_ops.append(op)
409406

410407
def __append_optimize_op__(op, block, grad_to_block_id):
411408
if self._is_opt_op(op):
@@ -1152,13 +1149,20 @@ def _get_optimize_pass(self):
11521149
op.input("Param")[0]),
11531150
self.origin_program.global_block().var(
11541151
op.input("Grad")[0])))
1155-
elif op.type == "scale":
1156-
# for adam optimize op
1157-
for in_name in op.input_arg_names:
1158-
if in_name.startswith("beta1_pow_acc") or \
1159-
in_name.startswith("beta2_pow_acc"):
1160-
opt_ops.append(op)
1161-
break
1152+
elif self._is_adam_connected_op(op):
1153+
opt_ops.append(op)
11621154
else:
11631155
pass
11641156
return opt_ops, params_grads
1157+
1158+
def _is_adam_connected_op(self, op):
1159+
"""
1160+
A hack function to determinate whether the input operator
1161+
is connected to optimize operator.
1162+
"""
1163+
if op.type == "scale":
1164+
for in_name in op.input_arg_names:
1165+
if in_name.startswith("beta1_pow_acc") or \
1166+
in_name.startswith("beta2_pow_acc"):
1167+
return True
1168+
return False

0 commit comments

Comments
 (0)