Skip to content

Commit 5e85384

Browse files
yi.wutyphoonzero
authored andcommitted
fix transpiler error
1 parent 997e8b9 commit 5e85384

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,12 +186,17 @@ def _init_splited_vars(self, split_method):
186186

187187
param_list = []
188188
grad_list = []
189+
param_grad_set = set()
189190
for p, g in self.params_grads:
190191
# skip parameter marked not trainable
191192
if type(p) == Parameter and p.trainable == False:
192193
continue
193-
param_list.append(p)
194-
grad_list.append(g)
194+
if p.name not in param_grad_set:
195+
param_list.append(p)
196+
param_grad_set.add(p.name)
197+
if g.name not in param_grad_set:
198+
grad_list.append(g)
199+
param_grad_set.add(g.name)
195200

196201
self._update_dist_lookup_table_vars(param_list, grad_list,
197202
self.params_grads)
@@ -802,6 +807,9 @@ def _create_vars_from_blocklist(self,
802807
if not block_map.has_key(varname):
803808
block_map[varname] = []
804809
block_map[varname].append((long(offset), long(size)))
810+
# Do not remove this important debug message:
811+
print("block map: %s" % block_map)
812+
805813
for varname, splited in block_map.iteritems():
806814
orig_var = program.global_block().var(varname)
807815
if len(splited) == 1:

0 commit comments

Comments
 (0)