from __future__ import print_function

import math
+ import numpy as np

from ps_dispatcher import RoundRobin, HashName, PSDispatcher
from .. import core, framework
@@ -103,15 +104,15 @@ def split_dense_variable(var_list, service_count, min_block_size=8192):

    We need to have a minimal block size so that the calculations in
    the parameter server side can gain better performance. By default
-   minimum block size 8K elements (maybe 16bit or 32bit or 64bit).
+   minimum block size 8K elements (maybe 16bit or 32bit or 64bit).

    Args:
        var_list (list): List of variables.
        service_count (int): Numel of pserver services. A pserver may have two
            or more listening ports.
        min_block_size (int): Minimum splitted block size.
    Returns:
-       blocks (list[(varname, block_id, current_block_size)]): A list
+       blocks (list[(varname, block_id, current_block_size)]): A list
            of VarBlocks. Each VarBlock specifies a shard of the var.
    """
    blocks = []
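To make the splitting rule in the docstring above concrete, here is a minimal standalone sketch of the same idea. `split_blocks`, its arguments, and the example shape are illustrative only and not part of this file; the sketch just mirrors the rounding behaviour the docstring describes.

```python
import math

def split_blocks(numel, width, pserver_count, min_block_size=8192):
    # Sketch: per-block element counts for one dense variable.
    # numel: total elements; width: product(dim[1:]), the unit blocks align to.
    split_count = min(pserver_count, max(1, numel // min_block_size))
    block_size = int(math.ceil(numel / float(split_count)))
    if width > 1 and block_size % width != 0:
        block_size += width - (block_size % width)   # round up to whole rows
    split_count = int(math.ceil(numel / float(block_size)))
    return [min(block_size, numel - i * block_size) for i in range(split_count)]

# A 1024 x 100 weight shared across 4 pservers splits into 4 row-aligned blocks:
print(split_blocks(1024 * 100, 100, 4))   # [25600, 25600, 25600, 25600]
```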
@@ -171,6 +172,7 @@ def transpile(self,
                  program=None,
                  pservers="127.0.0.1:6174",
                  trainers=1,
+                 align_var_to_block=True,
                  split_method=RoundRobin,
                  sync_mode=True):
        """
@@ -183,7 +185,8 @@ def transpile(self,
        parameter servers.

        Steps to transpile trainer:
-       1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
+       1. split variable to multiple blocks, aligned by product(dim[1:]) (width)
+          if align_var_to_block is True
        2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
        3. modify trainer program add split_op to each grad variable.
        4. append send_op to send splited variables to server and fetch
@@ -293,9 +296,18 @@ def transpile(self,
            for index in range(len(self.pserver_endpoints))
        ]

-       grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
-       param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
+       if align_var_to_block:
+           grad_blocks = split_dense_variable(grad_list,
+                                              len(pserver_endpoints))
+           param_blocks = split_dense_variable(param_list,
+                                               len(pserver_endpoints))
+       else:
+           # when we do NOT align var to block, we will always split params
+           # grads into one block.
+           grad_blocks = split_dense_variable(grad_list, 1)
+           param_blocks = split_dense_variable(param_list, 1)
        assert (len(grad_blocks) == len(param_blocks))
+
        # step2: Create new vars for the parameters and gradients blocks and
        # add ops to do the split.
        param_var_mapping = self._create_vars_from_blocklist(program,
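With the branch above, turning `align_var_to_block` off keeps every parameter and gradient in a single block, so whole variables (rather than row-aligned shards) are placed on the pservers. A hypothetical call site might look like the following; the `fluid` import path, the endpoint strings, and the `trainer_id` argument are assumptions, not taken from this diff.

```python
import paddle.fluid as fluid

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,                  # assumed argument, not shown in this hunk
    pservers="127.0.0.1:6174,127.0.0.1:6175",
    trainers=2,
    align_var_to_block=False,      # place whole variables instead of shards
    sync_mode=True)
```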
@@ -325,8 +337,22 @@ def transpile(self,
        # step 3.1: insert send op to send gradient vars to parameter servers
        ps_dispatcher.reset()
        send_vars = []
-       for orig_varname, splited_vars in grad_var_mapping.items():
+
+       # in general cases, the number of pservers is times of 2, and this
+       # will lead to uneven distribution among weights and bias:
+       # fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1
+       # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
+       # shuffle the map will avoid the uneven distribution above
+       grad_var_mapping_items = grad_var_mapping.items()
+       if not align_var_to_block:
+           np.random.shuffle(grad_var_mapping_items)
+
+       for orig_varname, splited_vars in grad_var_mapping_items:
            eplist = ps_dispatcher.dispatch(splited_vars)
+
+           if not align_var_to_block:
+               assert (len(splited_vars) == 1)
+
            if len(splited_vars) == 1:
                orig_varname = splited_vars[0].name
                index = find_op_by_output_arg(program.global_block(),
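The uneven placement described in the new comment block can be reproduced with a toy dispatcher. `round_robin` below is a stand-in, not the `RoundRobin` class imported earlier, and the gradient names are made up.

```python
import numpy as np

def round_robin(names, pservers):
    # assign the i-th variable to the (i mod N)-th pserver
    return [(name, pservers[i % len(pservers)]) for i, name in enumerate(names)]

grads = ["fc_0.w@GRAD", "fc_0.b@GRAD", "fc_1.w@GRAD", "fc_1.b@GRAD"]
print(round_robin(grads, ["ps1", "ps2"]))
# every large w gradient lands on ps1, every small b gradient on ps2

np.random.shuffle(grads)                    # same trick as the patch above
print(round_robin(grads, ["ps1", "ps2"]))   # assignment is now randomized
```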
@@ -374,7 +400,7 @@ def transpile(self,
        for i, ep in enumerate(eplist):
            self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
            self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
-       # step4: Concat the parameters splits together after recv.
+
        for varname, splited_var in param_var_mapping.iteritems():
            eps = []
            for var in splited_var:
@@ -399,6 +425,7 @@ def transpile(self,
                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
            })

+       # step4: Concat the parameters splits together after recv.
        for varname, splited_var in param_var_mapping.iteritems():
            if len(splited_var) <= 1:
                continue
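As a sanity check on the concat step named in the relocated comment, here is a small numpy stand-in (the shapes are made up) showing why row-aligned splits can simply be concatenated along axis 0 after they are received.

```python
import numpy as np

w = np.arange(12).reshape(6, 2)      # pretend 6x2 parameter
splits = [w[:3], w[3:]]              # two row-aligned shards, one per pserver
restored = np.concatenate(splits, axis=0)
assert np.array_equal(restored, w)   # original parameter is recovered
```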
@@ -849,8 +876,8 @@ def _create_vars_from_blocklist(self,
            program (ProgramDesc): ProgramDesc which gradients blong.
            block_list (list[(varname, block_id, block_size)]): List of gradient blocks.
            add_trainer_suffix (Bool): Add trainer suffix to new variable's name if set True.
-       Returns:
-           var_mapping (dict(varname->[new_varname_variable])): A dict mapping
+       Returns:
+           var_mapping (dict(varname->[new_varname_variable])): A dict mapping
                from original var name to each var split.
        """