@@ -188,16 +188,14 @@ def _split_tensors(self, coalesced_grads_and_grad_vars):
188
188
from ..layers import nn
189
189
for coalesced_grad , origin_grad_vars , grad_shapes in coalesced_grads_and_grad_vars :
190
190
grad_var_len = [np .prod (g_shape ) for g_shape in grad_shapes ]
191
- splited_vars = nn .split (
192
- coalesced_grad , num_or_sections = grad_var_len , dim = 0 )
193
- reshaped_grad_vars = []
194
- for g_var , g_shape in zip (splited_vars , grad_shapes ):
195
- reshaped_grad_vars .append (
196
- nn .reshape (
197
- x = g_var , shape = g_shape , inplace = True ))
198
- for origin_g_var , reshaped_g_var in zip (origin_grad_vars ,
199
- reshaped_grad_vars ):
200
- nn .assign (input = reshaped_g_var , output = origin_g_var )
191
+ self ._helper .main_program .current_block ().append_op (
192
+ type = 'split' ,
193
+ inputs = {'X' : coalesced_grad },
194
+ outputs = {'Out' : origin_grad_vars },
195
+ attrs = {'sections' : grad_var_len ,
196
+ 'axis' : 0 })
197
+ for g_var , g_shape in zip (origin_grad_vars , grad_shapes ):
198
+ nn .reshape (x = g_var , shape = g_shape , inplace = True )
201
199
202
200
def apply_collective_grads (self ):
203
201
"""
0 commit comments