Skip to content

Commit 1460648

Browse files
authored
update parallel.py (#19371)
test=release/1.5
1 parent 6fbd224 commit 1460648

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

python/paddle/fluid/dygraph/parallel.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -188,16 +188,14 @@ def _split_tensors(self, coalesced_grads_and_grad_vars):
188188
from ..layers import nn
189189
for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
190190
grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
191-
splited_vars = nn.split(
192-
coalesced_grad, num_or_sections=grad_var_len, dim=0)
193-
reshaped_grad_vars = []
194-
for g_var, g_shape in zip(splited_vars, grad_shapes):
195-
reshaped_grad_vars.append(
196-
nn.reshape(
197-
x=g_var, shape=g_shape, inplace=True))
198-
for origin_g_var, reshaped_g_var in zip(origin_grad_vars,
199-
reshaped_grad_vars):
200-
nn.assign(input=reshaped_g_var, output=origin_g_var)
191+
self._helper.main_program.current_block().append_op(
192+
type='split',
193+
inputs={'X': coalesced_grad},
194+
outputs={'Out': origin_grad_vars},
195+
attrs={'sections': grad_var_len,
196+
'axis': 0})
197+
for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
198+
nn.reshape(x=g_var, shape=g_shape, inplace=True)
201199

202200
def apply_collective_grads(self):
203201
"""

0 commit comments

Comments
 (0)