Skip to content

Commit b7c683b

Browse files
authored
Merge pull request #11326 from jacquesqiao/fix-distribute_transpiler
fix distribute_transpiler
2 parents d896134 + bf03a20 commit b7c683b

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def _update_dist_lookup_table_vars(self, param_list, grad_list,
177177
dtype=table_grad_var.dtype)
178178
for index in range(len(self.pserver_endpoints))
179179
]
180+
return param_list, grad_list
180181

181182
def _init_splited_vars(self, slice_var_up):
182183
# update these mappings for further transpile:
@@ -199,8 +200,8 @@ def _init_splited_vars(self, slice_var_up):
199200
grad_list.append(g)
200201
param_grad_set.add(g.name)
201202

202-
self._update_dist_lookup_table_vars(param_list, grad_list,
203-
self.params_grads)
203+
param_list, grad_list = self._update_dist_lookup_table_vars(
204+
param_list, grad_list, self.params_grads)
204205

205206
if slice_var_up:
206207
# when we slice var up into blocks, we will slice the var according to

0 commit comments

Comments
 (0)