Commit b33ea7b (parent 9d92dce)

1. Change the variable name from align_var_to_block to slice_var_up.
2. Replace split_method with slice_var_up in _init_splited_vars.

2 files changed (+20 lines, -17 lines)


python/paddle/fluid/tests/unittests/test_simple_dist_transpiler.py (1 addition, 1 deletion)

@@ -72,7 +72,7 @@ def _transpiler_instance(self):
             program=main,
             pservers=self.pserver_eps,
             trainers=self.trainers,
-            align_var_to_block=False)
+            slice_var_up=False)
         return t

python/paddle/fluid/transpiler/distribute_transpiler.py (19 additions, 16 deletions)

@@ -71,7 +71,7 @@ def same_or_split_var(p_name, var_name):
     return p_name == var_name or p_name.startswith(var_name + ".block")


-def split_variable(var_list, service_count, min_block_size=8192):
+def slice_variable(var_list, slice_count, min_block_size=8192):
     """
     We may need to split dense tensor to one or more blocks and put
     them equally onto parameter server. One block is a sub-tensor
@@ -83,21 +83,21 @@ def split_variable(var_list, service_count, min_block_size=8192):

     Args:
         var_list (list): List of variables.
-        service_count (int): Numel of pserver services. A pserver may have two
-            or more listening ports.
+        slice_count (int): Number of slices that the variables will be split
+            into, which could be the pserver services' count.
         min_block_size (int): Minimum splitted block size.
     Returns:
         blocks (list[(varname, block_id, current_block_size)]): A list
             of VarBlocks. Each VarBlock specifies a shard of the var.
     """
     blocks = []
     for var in var_list:
-        split_count = service_count
+        split_count = slice_count
         var_numel = reduce(lambda x, y: x * y, var.shape)
         max_pserver_count = int(math.floor(var_numel / float(min_block_size)))
         if max_pserver_count == 0:
             max_pserver_count = 1
-        if max_pserver_count < service_count:
+        if max_pserver_count < slice_count:
             split_count = max_pserver_count
         block_size = int(math.ceil(var_numel / float(split_count)))

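To make the renamed helper easier to follow, here is a minimal, self-contained sketch of the arithmetic in the hunk above; it is not part of the commit. The _FakeVar class, the example shapes, and the final loop that emits (varname, block_id, size) tuples are illustrative assumptions (the real function builds VarBlock objects, and the tail of its loop is outside this hunk):

import math
from functools import reduce


class _FakeVar(object):
    # Illustrative stand-in for a Paddle variable; only name and shape are used.
    def __init__(self, name, shape):
        self.name = name
        self.shape = shape


def slice_variable_sketch(var_list, slice_count, min_block_size=8192):
    # Mirrors the block-size arithmetic shown in the hunk above: each variable
    # is cut into at most slice_count blocks, but never into blocks smaller
    # than min_block_size elements.
    blocks = []
    for var in var_list:
        split_count = slice_count
        var_numel = reduce(lambda x, y: x * y, var.shape)
        max_pserver_count = int(math.floor(var_numel / float(min_block_size)))
        if max_pserver_count == 0:
            max_pserver_count = 1
        if max_pserver_count < slice_count:
            split_count = max_pserver_count
        block_size = int(math.ceil(var_numel / float(split_count)))
        # The loop below is reconstructed from the docstring's Returns section,
        # not from the diff itself.
        for block_id in range(split_count):
            curr_block_size = min(block_size, var_numel - block_id * block_size)
            blocks.append((var.name, block_id, curr_block_size))
    return blocks


# A 1000x1000 weight (1e6 elements) is sliced across 4 services, while a
# 100-element bias stays in one block because of min_block_size.
print(slice_variable_sketch(
    [_FakeVar("fc_w", [1000, 1000]), _FakeVar("fc_b", [100])], slice_count=4))
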
@@ -178,7 +178,7 @@ def _update_dist_lookup_table_vars(self, param_list, grad_list,
             for index in range(len(self.pserver_endpoints))
         ]

-    def _init_splited_vars(self, split_method, align_var_to_block=True):
+    def _init_splited_vars(self, slice_var_up):
         # update these mappings for further transpile:
         # 1. param_var_mapping: param var name -> [splited params vars]
         # 2. grad_var_mapping: grad var name -> [splited grads vars]
@@ -197,16 +197,19 @@ def _init_splited_vars(self, split_method, align_var_to_block=True):
         self._update_dist_lookup_table_vars(param_list, grad_list,
                                             self.params_grads)

-        if align_var_to_block:
-            grad_blocks = split_variable(grad_list, len(self.pserver_endpoints))
-            param_blocks = split_variable(param_list,
+        if slice_var_up:
+            # when we slice var up into blocks, we will slice the var according to
+            # pserver services' count. A pserver may have two or more listening ports.
+            grad_blocks = slice_variable(grad_list, len(self.pserver_endpoints))
+            param_blocks = slice_variable(param_list,
                                           len(self.pserver_endpoints))
         else:
-            # when we do NOT align var to block, we will always split params
+            # when we do NOT slice var up into blocks, we will always slice params
             # grads into one block.
-            grad_blocks = split_variable(grad_list, 1)
-            param_blocks = split_variable(param_list, 1)
+            grad_blocks = slice_variable(grad_list, 1)
+            param_blocks = slice_variable(param_list, 1)
         assert (len(grad_blocks) == len(param_blocks))
+
         # origin_varname -> [splited_var]
         self.param_var_mapping = self._create_vars_from_blocklist(
             self.origin_program, param_blocks)
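
As the hunk above shows, slice_var_up only changes the slice count handed to slice_variable. Continuing the hedged sketch from earlier (the endpoint list and shapes are assumptions):

pserver_endpoints = ["127.0.0.1:6174", "127.0.0.1:6175"]
params = [_FakeVar("fc_w", [1000, 1000]), _FakeVar("fc_b", [100])]

# slice_var_up=True: slice according to the pserver services' count.
print(slice_variable_sketch(params, len(pserver_endpoints)))
# slice_var_up=False: every parameter/gradient stays in exactly one block.
print(slice_variable_sketch(params, 1))
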
@@ -237,7 +240,7 @@ def transpile(self,
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
-                  align_var_to_block=True,
+                  slice_var_up=True,
                   split_method=RoundRobin,
                   sync_mode=True):
         """
@@ -271,7 +274,7 @@ def transpile(self,
         self.has_distributed_lookup_table = self._has_distributed_lookup_table()

         # split and create vars, then put splited vars in dicts for later use.
-        self._init_splited_vars(split_method, align_var_to_block)
+        self._init_splited_vars(slice_var_up)

         # step 3.1: insert send op to send gradient vars to parameter servers
         ps_dispatcher.reset()
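
From a caller's perspective, only the keyword name changes. A hypothetical call site after this commit (the import path, trainer_id, program, and endpoint values are assumptions based on the test file above, not taken from the diff):

import paddle.fluid as fluid

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,
    program=fluid.default_main_program(),
    pservers="127.0.0.1:6174,127.0.0.1:6175",
    trainers=2,
    slice_var_up=True,  # formerly align_var_to_block
    sync_mode=True)
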
@@ -283,13 +286,13 @@ def transpile(self,
         # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
         # shuffle the map will avoid the uneven distribution above
         grad_var_mapping_items = self.grad_var_mapping.items()
-        if not align_var_to_block:
+        if not slice_var_up:
             np.random.shuffle(grad_var_mapping_items)

         for orig_varname, splited_vars in grad_var_mapping_items:
             eplist = ps_dispatcher.dispatch(splited_vars)

-            if not align_var_to_block:
+            if not slice_var_up:
                 assert (len(splited_vars) == 1)

             if len(splited_vars) == 1:
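
The shuffle in this last hunk only matters when slice_var_up is False, where each gradient maps to exactly one block. A tiny hedged illustration of what gets shuffled (the variable names are made up):

import numpy as np

# Dispatching the items in a fixed dict order would send the same vars to the
# same pservers on every run; shuffling the item list first avoids that
# deterministic, potentially uneven distribution.
grad_var_mapping_items = list({
    "fc_w@GRAD": ["fc_w@GRAD"],
    "fc_b@GRAD": ["fc_b@GRAD"],
}.items())
np.random.shuffle(grad_var_mapping_items)
for orig_varname, splited_vars in grad_var_mapping_items:
    assert len(splited_vars) == 1
    print(orig_varname, "->", splited_vars[0])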
