@@ -1706,13 +1706,27 @@ def _get_param_block(opt_op):
1706
1706
outputs = outputs ,
1707
1707
attrs = opt_op .all_attrs ())
1708
1708
1709
def _get_pserver_grad_param_var(self, var, var_dict):
    """
    Return the pserver side grad/param variable for ``var``, or None
    if ``var`` is not a grad/param, e.g.

        a@GRAD -> a@GRAD (a is not split)
        fc_0.w_0 -> fc_0.w_0.block_0
        fc_0.w_0 -> fc_0.w_0 (weight is not split)
        _generated_var_123 -> None

    Args:
        var: variable to look up; only its ``name`` attribute is read.
        var_dict: mapping whose values are the candidate variables
            (each exposing ``name``); the keys are ignored.

    Returns:
        The first value of ``var_dict`` (in iteration order) that
        shares ``var``'s original name and is not a per-trainer
        variable, or None when there is no such value or when the
        original name is registered neither in
        ``self.grad_name_to_param_name`` nor in
        ``self.param_name_to_grad_name``.
    """
    # NOTE(review): assumes self._orig_varname is a pure name-mapping
    # helper (same input -> same output, no side effects), so it can
    # be evaluated once for ``var`` instead of once per candidate.
    orig_name = self._orig_varname(var.name)

    # only param or grads have split blocks; a matching candidate has
    # the same original name as ``var``, so decide this before scanning
    if orig_name not in self.grad_name_to_param_name and \
            orig_name not in self.param_name_to_grad_name:
        return None

    for candidate in var_dict.values():
        # skip per trainer vars
        if ".trainer_" in candidate.name:
            continue
        if self._orig_varname(candidate.name) == orig_name:
            return candidate
    return None
1717
1731
1718
1732
def _clone_lr_op (self , program , block , op ):
@@ -1745,32 +1759,38 @@ def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
1745
1759
for key , varlist in six .iteritems (inputs ):
1746
1760
if not isinstance (varlist , list ):
1747
1761
varlist = [varlist ]
1748
- for var in varlist :
1749
- # for ops like clipping and weight decay, get the splited var
1762
+ for i in range (len (varlist )):
1763
+ var = varlist [i ]
1764
+ # for ops like clipping and weight decay, get the splited var (xxx.block0)
1750
1765
# for inputs/outputs
1751
- grad_block = self ._is_splited_grad_var (
1766
+ grad_block = self ._get_pserver_grad_param_var (
1752
1767
var , program .global_block ().vars )
1753
1768
if grad_block :
1754
- inputs [ key ] = grad_block
1769
+ varlist [ i ] = grad_block
1755
1770
elif var .name not in program .global_block ().vars :
1756
- program .global_block ().create_var (
1757
- name = var . name ,
1758
- persistable = var . persistable ,
1759
- dtype = var .dtype ,
1760
- shape = var . shape )
1771
+ tmpvar = program .global_block ()._clone_variable ( var )
1772
+ varlist [ i ] = tmpvar
1773
+ else :
1774
+ varlist [ i ] = program . global_block (). vars [ var .name ]
1775
+ inputs [ key ] = varlist
1761
1776
1762
1777
outputs = self ._get_output_map_from_op (
1763
1778
self .origin_program .global_block ().vars , opt_op )
1764
1779
for key , varlist in six .iteritems (outputs ):
1765
1780
if not isinstance (varlist , list ):
1766
1781
varlist = [varlist ]
1767
- for var in varlist :
1768
- grad_block = self ._is_splited_grad_var (
1782
+ for i in range (len (varlist )):
1783
+ var = varlist [i ]
1784
+ grad_block = self ._get_pserver_grad_param_var (
1769
1785
var , program .global_block ().vars )
1770
1786
if grad_block :
1771
- outputs [ key ] = grad_block
1787
+ varlist [ i ] = grad_block
1772
1788
elif var .name not in program .global_block ().vars :
1773
- program .global_block ()._clone_variable (var )
1789
+ tmpvar = program .global_block ()._clone_variable (var )
1790
+ varlist [i ] = tmpvar
1791
+ else :
1792
+ varlist [i ] = program .global_block ().vars [var .name ]
1793
+ outputs [key ] = varlist
1774
1794
1775
1795
return optimize_block .append_op (
1776
1796
type = opt_op .type ,
0 commit comments