Skip to content

Commit b44b6a4

Browse files
authored
Merge pull request #9798 from typhoonzero/fix_dist_transpiler_bug
Fix dist transpiler bug
2 parents 49d431d + 92313a9 commit b44b6a4

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

python/paddle/fluid/distribute_transpiler.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -278,11 +278,21 @@ def get_pserver_program(self, endpoint):
278278
# we don't need to create them when grad arrives.
279279
# change client side var name to origin name by
280280
# removing ".trainer_%d" suffix
281+
281282
suff_idx = v.name.find(".trainer_")
282283
if suff_idx >= 0:
283284
orig_var_name = v.name[:suff_idx]
284285
else:
285286
orig_var_name = v.name
287+
# NOTE: single_trainer_var must be created for multi-trainer
288+
# case to merge grads from multiple trainers
289+
single_trainer_var = \
290+
pserver_program.global_block().create_var(
291+
name=orig_var_name,
292+
persistable=True,
293+
type=v.type,
294+
dtype=v.dtype,
295+
shape=v.shape)
286296
if self.trainers > 1:
287297
for trainer_id in xrange(self.trainers):
288298
var = pserver_program.global_block().create_var(
@@ -293,12 +303,6 @@ def get_pserver_program(self, endpoint):
293303
shape=v.shape)
294304
recv_inputs.append(var)
295305
else:
296-
single_trainer_var = pserver_program.global_block().create_var(
297-
name=orig_var_name,
298-
persistable=True,
299-
type=v.type,
300-
dtype=v.dtype,
301-
shape=v.shape)
302306
recv_inputs.append(single_trainer_var)
303307

304308
# step3

0 commit comments

Comments
 (0)