@@ -71,7 +71,7 @@ def same_or_split_var(p_name, var_name):
     return p_name == var_name or p_name.startswith(var_name + ".block")
 
 
-def split_variable(var_list, service_count, min_block_size=8192):
+def slice_variable(var_list, slice_count, min_block_size=8192):
     """
     We may need to split dense tensor to one or more blocks and put
     them equally onto parameter server. One block is a sub-tensor
@@ -83,21 +83,21 @@ def split_variable(var_list, service_count, min_block_size=8192):
 
     Args:
         var_list (list): List of variables.
-        service_count (int): Numel of pserver services. A pserver may have two
-            or more listening ports.
+        slice_count (int): The number of slices that the variables will be
+            sliced into, which could be the number of pserver services.
         min_block_size (int): Minimum splitted block size.
     Returns:
         blocks (list[(varname, block_id, current_block_size)]): A list
             of VarBlocks. Each VarBlock specifies a shard of the var.
     """
     blocks = []
     for var in var_list:
-        split_count = service_count
+        split_count = slice_count
         var_numel = reduce(lambda x, y: x * y, var.shape)
         max_pserver_count = int(math.floor(var_numel / float(min_block_size)))
         if max_pserver_count == 0:
            max_pserver_count = 1
-        if max_pserver_count < service_count:
+        if max_pserver_count < slice_count:
            split_count = max_pserver_count
        block_size = int(math.ceil(var_numel / float(split_count)))
 
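For intuition, here is a minimal, self-contained sketch of the slicing arithmetic in the hunk above (plain Python, no Paddle imports). The helper name slice_variable_sketch and the shapes dict are stand-ins introduced for illustration, and the per-block emission loop at the end is an assumption based on the docstring's (varname, block_id, current_block_size) return format; the real slice_variable also aligns blocks by dim[0], which this sketch ignores.

    import math
    from functools import reduce

    def slice_variable_sketch(shapes, slice_count, min_block_size=8192):
        """shapes: dict of varname -> shape tuple (a stand-in for var_list)."""
        blocks = []
        for name, shape in shapes.items():
            split_count = slice_count
            var_numel = reduce(lambda x, y: x * y, shape)
            # Cap the fan-out so no block falls below min_block_size elements.
            max_pserver_count = int(math.floor(var_numel / float(min_block_size)))
            if max_pserver_count == 0:
                max_pserver_count = 1
            if max_pserver_count < slice_count:
                split_count = max_pserver_count
            block_size = int(math.ceil(var_numel / float(split_count)))
            # Assumed emission step: one (varname, block_id, size) tuple per block.
            for block_id in range(split_count):
                curr = min(block_size, var_numel - block_id * block_size)
                blocks.append((name, block_id, curr))
        return blocks

    # A 1000x200 fc weight (200000 elements) sliced for 4 pserver services:
    # ceil(200000 / 4) = 50000 elements per block.
    print(slice_variable_sketch({"fc_w": (1000, 200)}, slice_count=4))
    # [('fc_w', 0, 50000), ('fc_w', 1, 50000), ('fc_w', 2, 50000), ('fc_w', 3, 50000)]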
@@ -178,7 +178,7 @@ def _update_dist_lookup_table_vars(self, param_list, grad_list,
             for index in range(len(self.pserver_endpoints))
         ]
 
-    def _init_splited_vars(self, split_method, align_var_to_block=True):
+    def _init_splited_vars(self, slice_var_up):
         # update these mappings for further transpile:
         # 1. param_var_mapping: param var name -> [splited params vars]
         # 2. grad_var_mapping: grad var name -> [splited grads vars]
@@ -197,16 +197,19 @@ def _init_splited_vars(self, split_method, align_var_to_block=True):
         self._update_dist_lookup_table_vars(param_list, grad_list,
                                             self.params_grads)
 
-        if align_var_to_block:
-            grad_blocks = split_variable(grad_list, len(self.pserver_endpoints))
-            param_blocks = split_variable(param_list,
+        if slice_var_up:
+            # when we slice vars up into blocks, we slice each var according to
+            # the pserver services' count. A pserver may have two or more
+            # listening ports.
+            grad_blocks = slice_variable(grad_list, len(self.pserver_endpoints))
+            param_blocks = slice_variable(param_list,
                                           len(self.pserver_endpoints))
         else:
-            # when we do NOT align var to block, we will always split params
+            # when we do NOT slice vars up into blocks, we always put params and
             # grads into one block.
-            grad_blocks = split_variable(grad_list, 1)
-            param_blocks = split_variable(param_list, 1)
+            grad_blocks = slice_variable(grad_list, 1)
+            param_blocks = slice_variable(param_list, 1)
         assert (len(grad_blocks) == len(param_blocks))
+
         # origin_varname -> [splited_var]
         self.param_var_mapping = self._create_vars_from_blocklist(
             self.origin_program, param_blocks)
@@ -237,7 +240,7 @@ def transpile(self,
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
-                  align_var_to_block=True,
+                  slice_var_up=True,
                   split_method=RoundRobin,
                   sync_mode=True):
         """
@@ -271,7 +274,7 @@ def transpile(self,
         self.has_distributed_lookup_table = self._has_distributed_lookup_table()
 
         # split and create vars, then put splited vars in dicts for later use.
-        self._init_splited_vars(split_method, align_var_to_block)
+        self._init_splited_vars(slice_var_up)
 
         # step 3.1: insert send op to send gradient vars to parameter servers
         ps_dispatcher.reset()
@@ -283,13 +286,13 @@ def transpile(self,
         # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
         # shuffle the map will avoid the uneven distribution above
         grad_var_mapping_items = self.grad_var_mapping.items()
-        if not align_var_to_block:
+        if not slice_var_up:
             np.random.shuffle(grad_var_mapping_items)
 
         for orig_varname, splited_vars in grad_var_mapping_items:
             eplist = ps_dispatcher.dispatch(splited_vars)
 
-            if not align_var_to_block:
+            if not slice_var_up:
                 assert (len(splited_vars) == 1)
 
             if len(splited_vars) == 1:
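To see what the renamed slice_var_up flag changes end to end, here is a hypothetical run reusing the slice_variable_sketch helper from the earlier sketch: with the flag on, each variable fans out across the pserver services; with it off, slice_count is forced to 1, every variable stays whole, and the shuffle above is what evens out the round-robin dispatch.

    # slice_var_up=True: slice across 2 pserver services.
    print(slice_variable_sketch({"fc_w": (1000, 200)}, slice_count=2))
    # [('fc_w', 0, 100000), ('fc_w', 1, 100000)]

    # slice_var_up=False: one block per variable; dispatch order is shuffled instead.
    print(slice_variable_sketch({"fc_w": (1000, 200)}, slice_count=1))
    # [('fc_w', 0, 200000)]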