Commit 34f2818

distribute transpiler support async config
1 parent 42a15a4 commit 34f2818
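
A minimal usage sketch (not part of this commit; the transpile() keyword arguments, including sync_mode, are assumptions based on the self.sync_mode flag the diff checks):

    from paddle.fluid.distribute_transpiler import DistributeTranspiler

    t = DistributeTranspiler()
    # Assumed API: sync_mode=False would request the async configuration
    # that the new self.sync_mode checks below support.
    t.transpile(
        trainer_id=0,
        pservers="127.0.0.1:6174,127.0.0.1:6175",
        trainers=2,
        sync_mode=False)
    pserver_prog = t.get_pserver_program("127.0.0.1:6174")  # method shown in the diff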

File tree

1 file changed: +23 -20 lines changed


python/paddle/fluid/distribute_transpiler.py

Lines changed: 23 additions & 20 deletions
@@ -358,7 +358,7 @@ def get_pserver_program(self, endpoint):
                 type=v.type,
                 dtype=v.dtype,
                 shape=v.shape)
-            if self.trainer_num > 1:
+            if self.sync_mode and self.trainer_num > 1:
                 for trainer_id in xrange(self.trainer_num):
                     var = pserver_program.global_block().create_var(
                         name="%s.trainer_%d" % (orig_var_name, trainer_id),
@@ -688,17 +688,6 @@ def _clone_var(block, var, persistable=True):
                 self.table_name)],
             persistable=False)
 
-        # create grad vars in pserver program
-        table_grad_var = self.table_param_grad[1]
-        table_grad_list = [
-            pserver_program.global_block().create_var(
-                name="%s.trainer_%d.pserver_%d" %
-                (table_grad_var.name, index, pserver_index),
-                type=table_grad_var.type,
-                shape=table_grad_var.shape,
-                dtype=table_grad_var.dtype) for index in range(self.trainer_num)
-        ]
-
         # create table optimize block in pserver program
         table_opt_op = [
             op for op in self.optimize_ops
@@ -708,11 +697,24 @@ def _clone_var(block, var, persistable=True):
         # only support sgd now
         assert table_opt_op.type == "sgd"
 
-        # append sum op for table_grad_list
-        table_opt_block.append_op(
-            type="sum",
-            inputs={"X": table_grad_list},
-            outputs={"Out": [grad_var]})
+        if self.sync_mode:
+            # create grad vars in pserver program
+            table_grad_var = self.table_param_grad[1]
+            table_grad_list = [
+                pserver_program.global_block().create_var(
+                    name="%s.trainer_%d.pserver_%d" %
+                    (table_grad_var.name, index, pserver_index),
+                    type=table_grad_var.type,
+                    shape=table_grad_var.shape,
+                    dtype=table_grad_var.dtype)
+                for index in range(self.trainer_num)
+            ]
+
+            # append sum op for table_grad_list
+            table_opt_block.append_op(
+                type="sum",
+                inputs={"X": table_grad_list},
+                outputs={"Out": [grad_var]})
 
         lr_var = pserver_program.global_block().vars[table_opt_op.input(
             "LearningRate")[0]]
@@ -751,7 +753,7 @@ def _create_vars_from_blocklist(self,
         for varname, splited in block_map.iteritems():
             orig_var = program.global_block().var(varname)
             if len(splited) == 1:
-                if add_trainer_suffix:
+                if self.sync_mode and add_trainer_suffix:
                     new_var_name = "%s.trainer_%d" % \
                         (orig_var.name, self.trainer_id)
                     program.global_block().rename_var(varname, new_var_name)
@@ -775,7 +777,7 @@ def _create_vars_from_blocklist(self,
             if len(orig_shape) >= 2:
                 splited_shape.extend(orig_shape[1:])
             new_var_name = ""
-            if add_trainer_suffix:
+            if self.sync_mode and add_trainer_suffix:
                 new_var_name = "%s.block%d.trainer_%d" % \
                     (varname, i, self.trainer_id)
             else:
@@ -907,7 +909,7 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
                     pserver_block.vars[self._orig_varname(grad_block.name)]
                 grad_to_block_id.append(merged_var.name + ":" + str(
                     optimize_block.idx))
-                if self.trainer_num > 1:
+                if self.sync_mode and self.trainer_num > 1:
                     vars2merge = []
                     for i in xrange(self.trainer_num):
                         per_trainer_name = "%s.trainer_%d" % \
@@ -925,6 +927,7 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
                         inputs={"X": merged_var},
                         outputs={"Out": merged_var},
                         attrs={"scale": 1.0 / float(self.trainer_num)})
+
                 new_inputs[key] = merged_var
             elif key == "Param":
                 # param is already created on global program
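
Taken together, the last two hunks keep gradient merging (a sum op over the per-trainer vars followed by a scale op) only in sync mode; in async mode each received gradient is applied without merging. A plain-Python sketch of the sync-mode arithmetic with made-up numbers:

    trainer_num = 2
    per_trainer_grads = [0.4, 0.8]       # hypothetical gradients from two trainers
    merged = sum(per_trainer_grads)      # corresponds to the "sum" op
    merged *= 1.0 / float(trainer_num)   # corresponds to the "scale" op -> 0.6 (the average)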
