Skip to content

Commit f3eafec

Browse files
committed
fix pserver weight decay multi inputs test=develop
1 parent 5b7a9dd commit f3eafec

File tree

1 file changed

+36
-16
lines changed

1 file changed

+36
-16
lines changed

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1706,13 +1706,27 @@ def _get_param_block(opt_op):
17061706
outputs=outputs,
17071707
attrs=opt_op.all_attrs())
17081708

1709-
def _is_splited_grad_var(self, var, var_dict):
1709+
def _get_pserver_grad_param_var(self, var, var_dict):
1710+
"""
1711+
Return the pserver-side grad/param variable, or None
1712+
if the variable is not grad/param, e.g.
1713+
1714+
1715+
a@GRAD -> a@GRAD (a is not split)
1716+
fc_0.w_0 -> fc_0.w_0.block_0
1717+
fc_0.w_0 -> fc_0.w_0 (weight is not split)
1718+
_generated_var_123 -> None
1719+
"""
17101720
grad_block = None
17111721
for _, g in six.iteritems(var_dict):
17121722
if self._orig_varname(g.name) == self._orig_varname(var.name):
1723+
# skip per trainer vars
17131724
if g.name.find(".trainer_") == -1:
1714-
grad_block = g
1715-
break
1725+
# only params or grads have split blocks
1726+
if self._orig_varname(g.name) in self.grad_name_to_param_name or\
1727+
self._orig_varname(g.name) in self.param_name_to_grad_name:
1728+
grad_block = g
1729+
break
17161730
return grad_block
17171731

17181732
def _clone_lr_op(self, program, block, op):
@@ -1745,32 +1759,38 @@ def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
17451759
for key, varlist in six.iteritems(inputs):
17461760
if not isinstance(varlist, list):
17471761
varlist = [varlist]
1748-
for var in varlist:
1749-
# for ops like clipping and weight decay, get the splited var
1762+
for i in range(len(varlist)):
1763+
var = varlist[i]
1764+
# for ops like clipping and weight decay, get the split var (xxx.block_0)
17501765
# for inputs/outputs
1751-
grad_block = self._is_splited_grad_var(
1766+
grad_block = self._get_pserver_grad_param_var(
17521767
var, program.global_block().vars)
17531768
if grad_block:
1754-
inputs[key] = grad_block
1769+
varlist[i] = grad_block
17551770
elif var.name not in program.global_block().vars:
1756-
program.global_block().create_var(
1757-
name=var.name,
1758-
persistable=var.persistable,
1759-
dtype=var.dtype,
1760-
shape=var.shape)
1771+
tmpvar = program.global_block()._clone_variable(var)
1772+
varlist[i] = tmpvar
1773+
else:
1774+
varlist[i] = program.global_block().vars[var.name]
1775+
inputs[key] = varlist
17611776

17621777
outputs = self._get_output_map_from_op(
17631778
self.origin_program.global_block().vars, opt_op)
17641779
for key, varlist in six.iteritems(outputs):
17651780
if not isinstance(varlist, list):
17661781
varlist = [varlist]
1767-
for var in varlist:
1768-
grad_block = self._is_splited_grad_var(
1782+
for i in range(len(varlist)):
1783+
var = varlist[i]
1784+
grad_block = self._get_pserver_grad_param_var(
17691785
var, program.global_block().vars)
17701786
if grad_block:
1771-
outputs[key] = grad_block
1787+
varlist[i] = grad_block
17721788
elif var.name not in program.global_block().vars:
1773-
program.global_block()._clone_variable(var)
1789+
tmpvar = program.global_block()._clone_variable(var)
1790+
varlist[i] = tmpvar
1791+
else:
1792+
varlist[i] = program.global_block().vars[var.name]
1793+
outputs[key] = varlist
17741794

17751795
return optimize_block.append_op(
17761796
type=opt_op.type,

0 commit comments

Comments
 (0)