Skip to content

Commit 9f0dcfd

Browse files
authored
Merge pull request #11155 from typhoonzero/fix_transpiler_merged_bug
Fix single pserver transpile error after merging
2 parents 87a5590 + 9d3114c commit 9f0dcfd

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,12 +187,17 @@ def _init_splited_vars(self, slice_var_up):
187187

188188
param_list = []
189189
grad_list = []
190+
param_grad_set = set()
190191
for p, g in self.params_grads:
191192
# skip parameter marked not trainable
192193
if type(p) == Parameter and p.trainable == False:
193194
continue
194-
param_list.append(p)
195-
grad_list.append(g)
195+
if p.name not in param_grad_set:
196+
param_list.append(p)
197+
param_grad_set.add(p.name)
198+
if g.name not in param_grad_set:
199+
grad_list.append(g)
200+
param_grad_set.add(g.name)
196201

197202
self._update_dist_lookup_table_vars(param_list, grad_list,
198203
self.params_grads)
@@ -829,6 +834,9 @@ def _create_vars_from_blocklist(self,
829834
if not block_map.has_key(varname):
830835
block_map[varname] = []
831836
block_map[varname].append((long(offset), long(size)))
837+
# Do not remove this important debug message:
838+
print("block map: %s" % block_map)
839+
832840
for varname, splited in block_map.iteritems():
833841
orig_var = program.global_block().var(varname)
834842
if len(splited) == 1:

0 commit comments

Comments
 (0)