@@ -102,6 +102,8 @@ def split_dense_variable(var_list,
         the parameter server side can gain better performance. By default
         minimum block size is 1024. The max block size is used to prevent
         very large blocks that may cause send error.
+        :return: A list of VarBlocks. Each VarBlock specifies a shard of
+           the var.
     """
     blocks = []
     for var in var_list:
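Example (not part of the diff): a minimal sketch of how the returned blocks are consumed downstream. It assumes each VarBlock stringifies as "varname:block_id:block_size", which matches the b.split(":") unpacking used later in transpile(); the variable name and sizes below are invented.

# Invented example: a 5000-element gradient sharded across 2 pservers.
grad_blocks = ["fc_0.w_0@GRAD:0:2500", "fc_0.w_0@GRAD:1:2500"]

for b in grad_blocks:
    varname, block_id, size = b.split(":")
    # -> ("fc_0.w_0@GRAD", 0, 2500) and ("fc_0.w_0@GRAD", 1, 2500)
    print("%s %d %d" % (varname, int(block_id), int(size)))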
@@ -192,22 +194,24 @@ def transpile(self,
         self.trainer_id = trainer_id
         pserver_endpoints = pservers.split(",")
 
-        # step1
+        # step1: For large parameters and gradients, split them into smaller
+        # blocks.
         param_list = [pg[0] for pg in params_grads]
         grad_list = [pg[1] for pg in params_grads]
         grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
         param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
-        # step2
+        # step2: Create new vars for the parameters and gradients blocks and
+        # add ops to do the split.
         grad_var_mapping = self._append_split_op(program, grad_blocks)
-        # step3
+        param_var_mapping = self._create_vars_from_blocklist(program,
+                                                             param_blocks)
+        # step3: Add gradients as send op inputs and parameters as send
+        # op outputs.
         send_inputs = []
         send_outputs = []
         for b in grad_blocks:  # append by order
             varname, block_id, _ = b.split(":")
             send_inputs.append(grad_var_mapping[varname][int(block_id)])
-
-        param_var_mapping = self._create_vars_from_blocklist(program,
-                                                             param_blocks)
         for b in param_blocks:
             varname, block_id, _ = b.split(":")
             send_outputs.append(param_var_mapping[varname][int(block_id)])
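To make steps 1-3 concrete, a small hedged walk-through of how send_inputs and send_outputs end up ordered. The names are invented and the mapped values stand in for the Variables that _append_split_op and _create_vars_from_blocklist would create.

# Invented example data; in the real transpiler the mapped values are
# fluid Variables, not strings.
grad_blocks = ["w@GRAD:0:2500", "w@GRAD:1:2500"]
param_blocks = ["w:0:2500", "w:1:2500"]
grad_var_mapping = {"w@GRAD": ["w@GRAD.block0", "w@GRAD.block1"]}
param_var_mapping = {"w": ["w.block0", "w.block1"]}

send_inputs, send_outputs = [], []
for b in grad_blocks:      # gradients feed the send op
    varname, block_id, _ = b.split(":")
    send_inputs.append(grad_var_mapping[varname][int(block_id)])
for b in param_blocks:     # parameters come back as send op outputs
    varname, block_id, _ = b.split(":")
    send_outputs.append(param_var_mapping[varname][int(block_id)])

# send_inputs  == ["w@GRAD.block0", "w@GRAD.block1"]
# send_outputs == ["w.block0", "w.block1"]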
@@ -237,7 +241,7 @@ def transpile(self,
                      "RPCClient": rpc_client_var},
             attrs={"endpoints": pserver_endpoints,
                    "epmap": eplist})
-        # step4
+        # step4: Concat the parameters splits together after recv.
         for varname, splited_var in param_var_mapping.iteritems():
             if len(splited_var) <= 1:
                 continue
@@ -258,13 +262,14 @@ def get_trainer_program(self):
     def get_pserver_program(self, endpoint):
         """
         Get pserver side program using the endpoint.
+        TODO(panyx0718): Revisit this assumption. what if #blocks > #pservers.
         NOTE: assume blocks of the same variable is not distributed
         on the same pserver, only change param/grad varnames for
         trainers to fetch.
         """
         # step1
         pserver_program = Program()
-        # step2
+        # step2: Create vars to receive vars at parameter servers.
         recv_inputs = []
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
             self._clone_var(pserver_program.global_block(), v)
@@ -278,12 +283,6 @@ def get_pserver_program(self, endpoint):
                 orig_var_name = v.name[:suff_idx]
             else:
                 orig_var_name = v.name
-            single_trainer_var = pserver_program.global_block().create_var(
-                name=orig_var_name,
-                persistable=True,
-                type=v.type,
-                dtype=v.dtype,
-                shape=v.shape)
             if self.trainers > 1:
                 for trainer_id in xrange(self.trainers):
                     var = pserver_program.global_block().create_var(
@@ -294,6 +293,12 @@ def get_pserver_program(self, endpoint):
                         shape=v.shape)
                     recv_inputs.append(var)
             else:
+                single_trainer_var = pserver_program.global_block().create_var(
+                    name=orig_var_name,
+                    persistable=True,
+                    type=v.type,
+                    dtype=v.dtype,
+                    shape=v.shape)
                 recv_inputs.append(single_trainer_var)
 
         # step3
@@ -344,7 +349,7 @@ def __append_optimize_op__(op, block):
                 self._append_pserver_non_opt_ops(block, op)
 
         append_block = optimize_block
-        # append lr decay ops to the child block if exits
+        # append lr decay ops to the child block if exists
         lr_ops = self._get_lr_ops()
         if len(lr_ops) > 0:
             for _, op in enumerate(lr_ops):
@@ -447,8 +452,10 @@ def _create_vars_from_blocklist(self,
                                     block_list,
                                     add_trainer_suffix=False):
         """
+        Create vars for each split.
         NOTE: only grads need to be named for different trainers, use
         add_trainer_suffix to rename the grad vars.
+        :return: A dict mapping from original var name to each var split.
         """
         block_map = dict()
         var_mapping = dict()
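The :return: value documented above can be pictured roughly as below; the names are made up and the real list entries are Variable objects created in program, not strings.

# Hypothetical shape of the mapping for a variable split into three blocks.
var_mapping = {
    "fc_0.w_0": [
        "fc_0.w_0.block0",  # shard 0
        "fc_0.w_0.block1",  # shard 1
        "fc_0.w_0.block2",  # shard 2
    ],
}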
@@ -615,6 +622,7 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
                         type="sum",
                         inputs={"X": vars2merge},
                         outputs={"Out": merged_var})
+                    # TODO(panyx0718): What if it's SELECTED_ROWS.
                     if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS:
                         optimize_block.append_op(
                             type="scale",
@@ -638,7 +646,7 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
                     shape=param_block.shape)
                 new_inputs[key] = tmpvar
             elif key == "LearningRate":
-                # leraning rate variable has already be created by non-optimize op,
+                # learning rate variable has already be created by non-optimize op,
                 # don't create it once again.
                 lr_varname = opt_op.input(key)[0]
                 if pserver_block.vars.has_key(lr_varname):
@@ -773,6 +781,7 @@ def _is_opt_op_on_pserver(self, endpoint, op):
         return False
 
     def _get_input_map_from_op(self, varmap, op):
+        """Returns a dict from op input name to the vars in varmap."""
         iomap = dict()
         for key in op.input_names:
             vars = []
@@ -785,6 +794,7 @@ def _get_input_map_from_op(self, varmap, op):
         return iomap
 
     def _get_output_map_from_op(self, varmap, op):
+        """Returns a dict from op output name to the vars in varmap."""
         iomap = dict()
         for key in op.output_names:
             vars = []
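A hedged sketch of what the two helpers above plausibly do, given their new docstrings; the single-vs-list handling is an assumption, and the slot names in the comment are just examples.

def get_input_map(varmap, op):
    """Returns a dict from op input name to the vars in varmap (sketch)."""
    iomap = dict()
    for key in op.input_names:  # e.g. ["Param", "Grad", "LearningRate"]
        vars = [varmap[name] for name in op.input(key)]
        # Assumption: a single-var slot is stored unwrapped, a multi-var
        # slot as a list.
        iomap[key] = vars[0] if len(vars) == 1 else vars
    return iomap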
@@ -812,6 +822,7 @@ def _get_lr_ops(self):
                 find_ops.append(op)
         # make a union find struct by the ops in default_main_program
         ufind = UnionFind(block.ops)
+
         for op1 in block.ops:
             for op2 in block.ops:
                 # NOTE: we need to skip all optimize ops, since it is connected
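The UnionFind used above is not shown in this diff. The sketch below is a generic union-find over ops, included only to illustrate how connected learning-rate ops can be grouped; the class and method names are assumptions, not the project's actual UnionFind.

class SimpleUnionFind(object):
    """Generic union-find keyed by object identity (illustrative only)."""

    def __init__(self, items):
        self._parent = dict((id(x), id(x)) for x in items)

    def _root(self, x):
        node = id(x)
        while self._parent[node] != node:
            node = self._parent[node]
        return node

    def union(self, a, b):
        ra, rb = self._root(a), self._root(b)
        if ra != rb:
            self._parent[rb] = ra

    def is_connected(self, a, b):
        return self._root(a) == self._root(b)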