Skip to content

Commit 772cdfe

Browse files
committed
fix single pserver error
1 parent ef802ce commit 772cdfe

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

python/paddle/fluid/distribute_transpiler.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -276,20 +276,25 @@ def get_pserver_program(self, endpoint):
276276
suff_idx = v.name.find(".trainer_")
277277
if suff_idx >= 0:
278278
orig_var_name = v.name[:suff_idx]
279-
pserver_program.global_block().create_var(
279+
else:
280+
orig_var_name = v.name
281+
single_trainer_var = pserver_program.global_block().create_var(
280282
name=orig_var_name,
281283
persistable=True,
282284
type=v.type,
283285
dtype=v.dtype,
284286
shape=v.shape)
285-
for trainer_id in xrange(self.trainers):
286-
var = pserver_program.global_block().create_var(
287-
name="%s.trainer_%d" % (orig_var_name, trainer_id),
288-
persistable=False,
289-
type=v.type,
290-
dtype=v.dtype,
291-
shape=v.shape)
292-
recv_inputs.append(var)
287+
if self.trainers > 1:
288+
for trainer_id in xrange(self.trainers):
289+
var = pserver_program.global_block().create_var(
290+
name="%s.trainer_%d" % (orig_var_name, trainer_id),
291+
persistable=False,
292+
type=v.type,
293+
dtype=v.dtype,
294+
shape=v.shape)
295+
recv_inputs.append(var)
296+
else:
297+
recv_inputs.append(single_trainer_var)
293298

294299
# step3
295300
optimize_block = pserver_program.create_block(0)
@@ -511,8 +516,11 @@ def _clone_var(self, block, var):
511516

512517
def _append_split_op(self, program, gradblocks):
513518
# Split variables that need to be split and append respective ops
519+
add_suffix = False
520+
if self.trainers > 1:
521+
add_suffix = True
514522
var_mapping = self._create_vars_from_blocklist(
515-
program, gradblocks, add_trainer_suffix=True)
523+
program, gradblocks, add_trainer_suffix=add_suffix)
516524
for varname, splited_vars in var_mapping.iteritems():
517525
# variable that don't need to split have empty splited_vars
518526
if len(splited_vars) <= 1:

0 commit comments

Comments
 (0)