Commit a0530c3

Merge pull request #12123 from jacquesqiao/distribute-transpiler-handle-adam-accumulator
Distribute transpiler handle adam accumulator
2 parents 486121d + 4344e57 commit a0530c3

1 file changed: +5 additions, −30 deletions

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 5 additions & 30 deletions
@@ -377,11 +377,6 @@ def get_pserver_program(self, endpoint):
         # append it into the sub program.

         global_ops = []
-        # HACK: optimization global ops only used to scale beta1 and beta2
-        # replace it with dependency engine.
-        for op in self.optimize_ops:
-            if self._is_adam_connected_op(op):
-                global_ops.append(op)

         def __append_optimize_op__(op, block, grad_to_block_id, merged_var,
                                    lr_ops):
@@ -1289,22 +1284,16 @@ def _is_op_connected(self, op1, op2):
         # If one op's input is another op's output or
         # one op's output is another op's input, we say
         # the two operator is connected.
-        def _append_inname_remove_beta(varname_list):
+        def _append_inname(varname_list):
             op_input_names = []
             for in_name in varname_list:
-                # HACK: remove beta1 and beta2 to avoid let all
-                # ops connected.
-                if in_name.startswith("beta2_pow_acc") or \
-                        in_name.startswith("beta1_pow_acc"):
-                    continue
-                else:
-                    op_input_names.append(in_name)
+                op_input_names.append(in_name)
             return op_input_names

-        op1_input_names = _append_inname_remove_beta(op1.desc.input_arg_names())
+        op1_input_names = _append_inname(op1.desc.input_arg_names())
         op1_output_names = op1.desc.output_arg_names()

-        op2_input_names = _append_inname_remove_beta(op2.desc.input_arg_names())
+        op2_input_names = _append_inname(op2.desc.input_arg_names())
         op2_output_names = op2.desc.output_arg_names()

         if set(op1_output_names) & set(op2_input_names) or \
@@ -1413,7 +1402,7 @@ def _is_opt_role_op(self, op):

     def _get_optimize_pass(self):
         """
-        Get optimizer operators, paramters and gradients from origin_program
+        Get optimizer operators, parameters and gradients from origin_program
         Returns:
             opt_ops (list): optimize operators.
             params_grads (dict): paramter->gradient.
@@ -1436,20 +1425,6 @@ def _get_optimize_pass(self):
                             origin_var_dict[param_name],
                             origin_var_dict[input_name]
                         ])
-            elif self._is_adam_connected_op(op):
-                opt_ops.append(op)
             else:
                 pass
         return opt_ops, params_grads
-
-    def _is_adam_connected_op(self, op):
-        """
-        A hack function to determinate whether the input operator
-        is connected to optimize operator.
-        """
-        if op.type == "scale":
-            for in_name in op.input_arg_names:
-                if in_name.startswith("beta1_pow_acc") or \
-                        in_name.startswith("beta2_pow_acc"):
-                    return True
-        return False
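
With the beta1_pow_acc / beta2_pow_acc special case removed, _is_op_connected reduces to a plain set-intersection test over input and output variable names, so the scale ops that update Adam's accumulators are picked up through ordinary op connectivity instead of the dedicated _is_adam_connected_op hack. A minimal standalone sketch of that rule follows; FakeOp and the variable names are illustrative only, not part of the patch:

# Sketch of the connectivity rule left after this patch: two ops are
# "connected" iff one op's outputs intersect the other op's inputs.
# FakeOp and the names below are made up for illustration only.
class FakeOp(object):
    def __init__(self, inputs, outputs):
        self.inputs = inputs
        self.outputs = outputs


def is_connected(op1, op2):
    # No filtering of beta1_pow_acc / beta2_pow_acc any more: every
    # input name takes part in the intersection test.
    return bool(
        set(op1.outputs) & set(op2.inputs) or
        set(op2.outputs) & set(op1.inputs))


adam = FakeOp(
    inputs=["param", "grad", "moment1", "moment2", "beta1_pow_acc_0"],
    outputs=["param_out"])
scale_beta1 = FakeOp(
    inputs=["beta1_pow_acc_0"], outputs=["beta1_pow_acc_0"])

print(is_connected(adam, scale_beta1))  # True: they share the accumulator variable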
