Skip to content

Commit 92313a9

Browse files
committed
update
1 parent d02b17e commit 92313a9

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

python/paddle/fluid/distribute_transpiler.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,12 @@ def get_pserver_program(self, endpoint):
278278
# we don't need to create them when grad arrives.
279279
# change client side var name to origin name by
280280
# removing ".trainer_%d" suffix
281+
282+
suff_idx = v.name.find(".trainer_")
283+
if suff_idx >= 0:
284+
orig_var_name = v.name[:suff_idx]
285+
else:
286+
orig_var_name = v.name
281287
# NOTE: single_trainer_var must be created for multi-trainer
282288
# case to merge grads from multiple trainers
283289
single_trainer_var = \
@@ -287,11 +293,6 @@ def get_pserver_program(self, endpoint):
287293
type=v.type,
288294
dtype=v.dtype,
289295
shape=v.shape)
290-
suff_idx = v.name.find(".trainer_")
291-
if suff_idx >= 0:
292-
orig_var_name = v.name[:suff_idx]
293-
else:
294-
orig_var_name = v.name
295296
if self.trainers > 1:
296297
for trainer_id in xrange(self.trainers):
297298
var = pserver_program.global_block().create_var(

0 commit comments

Comments
 (0)