
Commit 21c7db4

Author: weixing02
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into build_manually
2 parents 9bec0d2 + e150591 · commit 21c7db4

File tree

1 file changed: +27 −16 lines

python/paddle/fluid/distribute_transpiler.py

Lines changed: 27 additions & 16 deletions
@@ -102,6 +102,8 @@ def split_dense_variable(var_list,
         the parameter server side can gain better performance. By default
         minimum block size is 1024. The max block size is used to prevent
         very large blocks that may cause send error.
+    :return: A list of VarBlocks. Each VarBlock specifies a shard of
+        the var.
     """
     blocks = []
     for var in var_list:
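To make the return description concrete, here is a small standalone sketch of the splitting rule the docstring describes; it is not the Paddle implementation, and the function name and numbers are made up, but the "varname:block_id:size" encoding matches the b.split(":") usage seen later in transpile():

import math

def split_dense_var_sketch(var_name, numel, pserver_count, min_block_size=1024):
    # Aim for one block per pserver, but never let a block drop below the
    # minimum block size mentioned in the docstring above.
    split_count = pserver_count
    if numel < min_block_size * split_count:
        split_count = max(1, numel // min_block_size)
    block_size = int(math.ceil(numel / float(split_count)))
    blocks = []
    for block_id in range(split_count):
        size = min(block_size, numel - block_id * block_size)
        blocks.append("%s:%d:%d" % (var_name, block_id, size))
    return blocks

print(split_dense_var_sketch("fc_0.w_0", numel=5000, pserver_count=4))
# ['fc_0.w_0:0:1250', 'fc_0.w_0:1:1250', 'fc_0.w_0:2:1250', 'fc_0.w_0:3:1250']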
@@ -192,22 +194,24 @@ def transpile(self,
         self.trainer_id = trainer_id
         pserver_endpoints = pservers.split(",")

-        # step1
+        # step1: For large parameters and gradients, split them into smaller
+        # blocks.
         param_list = [pg[0] for pg in params_grads]
         grad_list = [pg[1] for pg in params_grads]
         grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
         param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
-        # step2
+        # step2: Create new vars for the parameters and gradients blocks and
+        # add ops to do the split.
         grad_var_mapping = self._append_split_op(program, grad_blocks)
-        # step3
+        param_var_mapping = self._create_vars_from_blocklist(program,
+                                                             param_blocks)
+        # step3: Add gradients as send op inputs and parameters as send
+        # op outputs.
         send_inputs = []
         send_outputs = []
         for b in grad_blocks:  # append by order
             varname, block_id, _ = b.split(":")
             send_inputs.append(grad_var_mapping[varname][int(block_id)])
-
-        param_var_mapping = self._create_vars_from_blocklist(program,
-                                                             param_blocks)
         for b in param_blocks:
             varname, block_id, _ = b.split(":")
             send_outputs.append(param_var_mapping[varname][int(block_id)])
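Step 3 only does bookkeeping on those block descriptors. The toy below replays that lookup outside the framework, with plain strings standing in for the Variable objects that the split step would create (the variable names are invented):

grad_blocks = ["fc_0.w_0@GRAD:0:1250", "fc_0.w_0@GRAD:1:1250"]
grad_var_mapping = {
    "fc_0.w_0@GRAD": ["fc_0.w_0@GRAD.block0", "fc_0.w_0@GRAD.block1"]
}

send_inputs = []
for b in grad_blocks:  # append by order, as in the loop above
    varname, block_id, _ = b.split(":")
    send_inputs.append(grad_var_mapping[varname][int(block_id)])

print(send_inputs)
# ['fc_0.w_0@GRAD.block0', 'fc_0.w_0@GRAD.block1']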
@@ -237,7 +241,7 @@ def transpile(self,
                      "RPCClient": rpc_client_var},
             attrs={"endpoints": pserver_endpoints,
                    "epmap": eplist})
-        # step4
+        # step4: Concat the parameters splits together after recv.
         for varname, splited_var in param_var_mapping.iteritems():
             if len(splited_var) <= 1:
                 continue
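Step 4 is the inverse of step 1: the splits fetched back from the pservers are joined into a single parameter again, a concat along dim 0. The flat-list toy below shows only the bookkeeping, not the real concat op:

splited_var = [[0.1, 0.2], [0.3, 0.4], [0.5]]  # splits fetched from pservers
merged = []
for piece in splited_var:  # same order in which the var was split
    merged.extend(piece)
print(merged)  # [0.1, 0.2, 0.3, 0.4, 0.5]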
@@ -258,13 +262,14 @@ def get_trainer_program(self):
     def get_pserver_program(self, endpoint):
         """
         Get pserver side program using the endpoint.
+        TODO(panyx0718): Revisit this assumption. what if #blocks > #pservers.
         NOTE: assume blocks of the same variable is not distributed
         on the same pserver, only change param/grad varnames for
         trainers to fetch.
         """
         # step1
         pserver_program = Program()
-        # step2
+        # step2: Create vars to receive vars at parameter servers.
         recv_inputs = []
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
             self._clone_var(pserver_program.global_block(), v)
@@ -278,12 +283,6 @@ def get_pserver_program(self, endpoint):
                 orig_var_name = v.name[:suff_idx]
             else:
                 orig_var_name = v.name
-            single_trainer_var = pserver_program.global_block().create_var(
-                name=orig_var_name,
-                persistable=True,
-                type=v.type,
-                dtype=v.dtype,
-                shape=v.shape)
             if self.trainers > 1:
                 for trainer_id in xrange(self.trainers):
                     var = pserver_program.global_block().create_var(
@@ -294,6 +293,12 @@ def get_pserver_program(self, endpoint):
                         shape=v.shape)
                     recv_inputs.append(var)
             else:
+                single_trainer_var = pserver_program.global_block().create_var(
+                    name=orig_var_name,
+                    persistable=True,
+                    type=v.type,
+                    dtype=v.dtype,
+                    shape=v.shape)
                 recv_inputs.append(single_trainer_var)

         # step3
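The relocated create_var calls encode a naming rule: with more than one trainer, the pserver receives a separate copy of each gradient block per trainer so the copies can be merged later; with a single trainer the original block variable is reused. A string-only sketch of that rule (the ".trainer_<id>" suffix is illustrative, not necessarily the exact suffix used):

def recv_var_names_sketch(orig_var_name, trainers):
    # One var per trainer when their gradients must be summed on the pserver,
    # otherwise a single var with the original name.
    if trainers > 1:
        return ["%s.trainer_%d" % (orig_var_name, tid) for tid in range(trainers)]
    return [orig_var_name]

print(recv_var_names_sketch("fc_0.w_0@GRAD.block0", trainers=3))
# ['fc_0.w_0@GRAD.block0.trainer_0', 'fc_0.w_0@GRAD.block0.trainer_1',
#  'fc_0.w_0@GRAD.block0.trainer_2']
print(recv_var_names_sketch("fc_0.w_0@GRAD.block0", trainers=1))
# ['fc_0.w_0@GRAD.block0']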
@@ -344,7 +349,7 @@ def __append_optimize_op__(op, block):
                 self._append_pserver_non_opt_ops(block, op)

         append_block = optimize_block
-        # append lr decay ops to the child block if exits
+        # append lr decay ops to the child block if exists
         lr_ops = self._get_lr_ops()
         if len(lr_ops) > 0:
             for _, op in enumerate(lr_ops):
@@ -447,8 +452,10 @@ def _create_vars_from_blocklist(self,
                                     block_list,
                                     add_trainer_suffix=False):
         """
+        Create vars for each split.
         NOTE: only grads need to be named for different trainers, use
         add_trainer_suffix to rename the grad vars.
+        :return: A dict mapping from original var name to each var split.
         """
         block_map = dict()
         var_mapping = dict()
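The returned mapping can be pictured without framework objects: group the block descriptors by original var name and emit one entry per split. The ".blockN" names below are invented stand-ins for the Variables the real method creates in the program:

block_list = ["fc_0.w_0:0:1250", "fc_0.w_0:1:1250", "fc_0.b_0:0:10"]

block_map = dict()
for b in block_list:
    varname, offset, size = b.split(":")
    block_map.setdefault(varname, []).append((int(offset), int(size)))

var_mapping = dict()
for varname, splits in block_map.items():
    var_mapping[varname] = ["%s.block%d" % (varname, i)
                            for i in range(len(splits))]

print(var_mapping["fc_0.w_0"])  # ['fc_0.w_0.block0', 'fc_0.w_0.block1']
print(var_mapping["fc_0.b_0"])  # ['fc_0.b_0.block0']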
@@ -615,6 +622,7 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
                 type="sum",
                 inputs={"X": vars2merge},
                 outputs={"Out": merged_var})
+            # TODO(panyx0718): What if it's SELECTED_ROWS.
             if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS:
                 optimize_block.append_op(
                     type="scale",
@@ -638,7 +646,7 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
                     shape=param_block.shape)
                 new_inputs[key] = tmpvar
             elif key == "LearningRate":
-                # leraning rate variable has already be created by non-optimize op,
+                # learning rate variable has already be created by non-optimize op,
                 # don't create it once again.
                 lr_varname = opt_op.input(key)[0]
                 if pserver_block.vars.has_key(lr_varname):
@@ -773,6 +781,7 @@ def _is_opt_op_on_pserver(self, endpoint, op):
             return False

     def _get_input_map_from_op(self, varmap, op):
+        """Returns a dict from op input name to the vars in varmap."""
         iomap = dict()
         for key in op.input_names:
             vars = []
@@ -785,6 +794,7 @@ def _get_input_map_from_op(self, varmap, op):
         return iomap

     def _get_output_map_from_op(self, varmap, op):
+        """Returns a dict from op output name to the vars in varmap."""
         iomap = dict()
         for key in op.output_names:
             vars = []
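Both helpers build the same shape of result: a dict keyed by the op's slot names, with the matching vars looked up from varmap. A toy version with a fake op and string stand-ins (whether single-element slots are unwrapped, as done here, is an assumption for illustration):

class FakeOp(object):
    # Stand-in for a framework op: slot names plus the var names in each slot.
    input_names = ["X", "Y"]

    def input(self, key):
        return {"X": ["x0"], "Y": ["y0", "y1"]}[key]

varmap = {"x0": "var_x0", "y0": "var_y0", "y1": "var_y1"}

op = FakeOp()
iomap = dict()
for key in op.input_names:
    vars = [varmap[name] for name in op.input(key)]
    iomap[key] = vars[0] if len(vars) == 1 else vars

print(iomap["X"])  # var_x0
print(iomap["Y"])  # ['var_y0', 'var_y1']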
@@ -812,6 +822,7 @@ def _get_lr_ops(self):
                 find_ops.append(op)
         # make a union find struct by the ops in default_main_program
         ufind = UnionFind(block.ops)
+
         for op1 in block.ops:
             for op2 in block.ops:
                 # NOTE: we need to skip all optimize ops, since it is connected
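The surrounding _get_lr_ops code unions ops that share a variable so the learning-rate-decay ops come out as one connected group. A minimal, generic union-find (not Paddle's UnionFind class) is enough to show the idea:

class UnionFindSketch(object):
    """Minimal union-find keyed by object identity."""

    def __init__(self, items):
        self.parent = dict((id(x), id(x)) for x in items)

    def find(self, x):
        r = id(x)
        while self.parent[r] != r:
            r = self.parent[r]
        return r

    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            self.parent[ra] = rb

    def is_connected(self, a, b):
        return self.find(a) == self.find(b)

ops = ["fill_constant", "elementwise_div", "sgd"]  # names only, not real ops
uf = UnionFindSketch(ops)
uf.union(ops[0], ops[1])  # e.g. two ops that share the learning-rate var
print(uf.is_connected(ops[0], ops[1]))  # True
print(uf.is_connected(ops[0], ops[2]))  # False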
