@@ -93,30 +93,33 @@ def same_or_split_var(p_name, var_name):
     return p_name == var_name or p_name.startswith(var_name + ".block")


-def split_dense_variable(var_list,
-                         pserver_count,
-                         min_block_size=1024,
-                         max_block_size=1048576):
+def split_dense_variable(var_list, service_count, min_block_size=8192):
     """
-    We may need to split dense tensor to one or more blocks and put
-    them equally onto parameter server. One block is a sub-tensor
-    aligned by dim[0] of the tensor.
-
-    We need to have a minimal block size so that the calculations in
-    the parameter server side can gain better performance. By default
-    minimum block size is 1024. The max block size is used to prevent
-    very large blocks that may cause send error.
-    :return: A list of VarBlocks. Each VarBlock specifies a shard of
-        the var.
+    We may need to split a dense tensor into one or more blocks and put
+    them equally onto parameter servers. One block is a sub-tensor
+    aligned by dim[0] of the tensor.
+
+    We need a minimal block size so that the calculations on the
+    parameter server side can gain better performance. By default the
+    minimum block size is 8K elements (each may be 16-bit, 32-bit or 64-bit).
+
+    Args:
+        var_list (list): List of variables.
+        service_count (int): Number of pserver services. A pserver may have two
+            or more listening ports.
+        min_block_size (int): Minimum split block size.
+    Returns:
+        blocks (list[(varname, block_id, current_block_size)]): A list
+            of VarBlocks. Each VarBlock specifies a shard of the var.
     """
     blocks = []
     for var in var_list:
-        split_count = pserver_count
+        split_count = service_count
         var_numel = reduce(lambda x, y: x * y, var.shape)
         max_pserver_count = int(math.floor(var_numel / float(min_block_size)))
         if max_pserver_count == 0:
             max_pserver_count = 1
-        if max_pserver_count < pserver_count:
+        if max_pserver_count < service_count:
             split_count = max_pserver_count
         block_size = int(math.ceil(var_numel / float(split_count)))
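To make the sharding arithmetic above concrete, here is a small standalone sketch that reproduces the same count/size computation; the `sketch_split` helper and the example shape are illustrative only, not part of the transpiler:

```python
import math
from functools import reduce


def sketch_split(shape, service_count, min_block_size=8192):
    """Mirror the logic above: never shard below min_block_size elements,
    and never use more shards than there are pserver services."""
    var_numel = reduce(lambda x, y: x * y, shape)
    max_pserver_count = max(1, int(math.floor(var_numel / float(min_block_size))))
    split_count = min(service_count, max_pserver_count)
    block_size = int(math.ceil(var_numel / float(split_count)))
    return split_count, block_size


# A 1000 x 64 parameter (64000 elements) across 4 pserver services:
# floor(64000 / 8192) = 7 possible shards, capped at 4 services,
# so each block holds ceil(64000 / 4) = 16000 elements.
print(sketch_split((1000, 64), 4))  # -> (4, 16000)
```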
@@ -270,16 +273,19 @@ def transpile(self,
         grad_var_mapping = self._append_split_op(program, grad_blocks)
         param_var_mapping = self._create_vars_from_blocklist(program,
                                                               param_blocks)
+
         # step3: Add gradients as send op inputs and parameters as send
         # op outputs.
         send_inputs = []
         send_outputs = []
         for b in grad_blocks:  # append by order
             varname, block_id, _ = b.split(":")
             send_inputs.append(grad_var_mapping[varname][int(block_id)])
+
         for b in param_blocks:
             varname, block_id, _ = b.split(":")
             send_outputs.append(param_var_mapping[varname][int(block_id)])
+
         # let send_op know which endpoint to send which var to, eplist has the same
         # order as send_inputs.
         eplist = split_method(send_inputs, pserver_endpoints)
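Each entry in `grad_blocks` / `param_blocks` is a `"varname:block_id:size"` string, and `split_method` decides which endpoint serves each block. The round-robin policy below (with made-up variable names and ports) is only an assumed illustration of that contract, not necessarily the `split_method` implementation passed to `transpile`:

```python
def round_robin_eplist(blocks, endpoints):
    # Assign block i to endpoint i % len(endpoints); the returned eplist
    # keeps the same order as the block list, as the comment above notes.
    return [endpoints[i % len(endpoints)] for i, _ in enumerate(blocks)]


grad_blocks = ["fc_0.w_0:0:16000", "fc_0.w_0:1:16000", "fc_0.b_0:0:64"]
pservers = ["127.0.0.1:6174", "127.0.0.1:6175"]
print(round_robin_eplist(grad_blocks, pservers))
# -> ['127.0.0.1:6174', '127.0.0.1:6175', '127.0.0.1:6174']
```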
@@ -751,9 +757,18 @@ def _create_vars_from_blocklist(self,
         Create vars for each split.
         NOTE: only grads need to be named for different trainers, use
         add_trainer_suffix to rename the grad vars.
-        :return: A dict mapping from original var name to each var split.
+        Args:
+            program (ProgramDesc): ProgramDesc to which the gradients belong.
+            block_list (list[(varname, block_id, block_size)]): List of gradient blocks.
+            add_trainer_suffix (bool): Add trainer suffix to new variable's name if set True.
+        Returns:
+            var_mapping (dict(varname->[new_varname_variable])): A dict mapping
+                from original var name to each var split.
         """
+
+        # varname->[(block_id, current_block_size)]
         block_map = dict()
+
         var_mapping = dict()
         for block_str in block_list:
             varname, offset, size = block_str.split(":")
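The `block_map` comment added above describes a `varname -> [(block_id, current_block_size)]` grouping. A minimal sketch of just that grouping step (the real method additionally creates the split variables in the program) could look like:

```python
def build_block_map(block_list):
    # Group "varname:block_id:size" strings by variable name.
    block_map = dict()
    for block_str in block_list:
        varname, offset, size = block_str.split(":")
        block_map.setdefault(varname, []).append((int(offset), int(size)))
    return block_map


print(build_block_map(["w:0:16000", "w:1:16000", "b:0:64"]))
# -> {'w': [(0, 16000), (1, 16000)], 'b': [(0, 64)]}
```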
@@ -824,7 +839,16 @@ def _clone_var(self, block, var, persistable=True):
             persistable=persistable)

     def _append_split_op(self, program, gradblocks):
-        # Split variables that need to be split and append respective ops
+        """
+        Split variables that need to be split and append respective ops.
+        Args:
+            program (ProgramDesc): ProgramDesc to which the gradients belong.
+            gradblocks (list[(varname, block_id, block_size)]): List of gradient blocks.
+        Returns:
+            var_mapping (dict(varname->[new_splitted_variable])): A dict mapping
+                from original var name to each var split.
+        """
+
         add_suffix = False
         if self.trainer_num > 1:
             add_suffix = True
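`_append_split_op` emits split ops inside the program; the snippet below only illustrates the dim[0]-aligned shard shapes those ops produce, using NumPy arrays as a stand-in (assumed for illustration) rather than program variables:

```python
import numpy as np

w = np.arange(12).reshape(6, 2)        # a 6 x 2 parameter
shards = np.array_split(w, 3, axis=0)  # three blocks aligned by dim[0]
print([s.shape for s in shards])       # -> [(2, 2), (2, 2), (2, 2)]
```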
@@ -1148,6 +1172,12 @@ def _get_lr_ops(self):
         return lr_ops

     def _get_optimize_pass(self):
+        """
+        Get optimizer operators, parameters and gradients from origin_program.
+        Returns:
+            opt_ops (list): optimize operators.
+            params_grads (list): parameter->gradient pairs.
+        """
         block = self.origin_program.global_block()
         opt_ops = []
         params_grads = []
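A framework-free sketch of the pairing this pass produces, with op records reduced to plain tuples purely for illustration (the real pass walks `origin_program`'s global block and its operators):

```python
# Hypothetical op records: (op_type, parameter, gradient).
ops = [
    ("mul", None, None),
    ("sgd", "fc_0.w_0", "fc_0.w_0@GRAD"),
    ("sgd", "fc_0.b_0", "fc_0.b_0@GRAD"),
]
OPTIMIZE_OP_TYPES = {"sgd", "adam", "momentum"}  # assumed set of optimizer op types

opt_ops = [op for op in ops if op[0] in OPTIMIZE_OP_TYPES]
params_grads = [(p, g) for _, p, g in opt_ops]
print(params_grads)
# -> [('fc_0.w_0', 'fc_0.w_0@GRAD'), ('fc_0.b_0', 'fc_0.b_0@GRAD')]
```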