# See the License for the specific language governing permissions and
# limitations under the License.
"""
-Transpile the program to distributed data-parallelism programs.
-The main_program will be transformed to use a remote parameter server
-to do parameter optimization. And the optimization graph will be put
-into a parameter server program.
-
-Use different methods to split trainable variables to different
-parameter servers.
-
Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
@@ -118,128 +110,40 @@ def slice_variable(var_list, slice_count, min_block_size=8192):


class DistributeTranspiler:
-    def _has_distributed_lookup_table(self):
-        # process lookup_table_op
-        # 1. check all lookup_table_op is distributed
-        # 2. check all lookup_table_op share the same table.
-        distributed_lookup_table_ops = []
-        # support only one distributed_lookup_table now
-        self.table_name = None
-        for op in self.origin_program.global_block().ops:
-            if op.type == LOOKUP_TABLE_TYPE:
-                if op.attrs['is_distributed'] is True:
-                    if self.table_name is None:
-                        self.table_name = op.input("W")[0]
-                    if self.table_name != op.input("W")[0]:
-                        raise RuntimeError("all distributed lookup_table_ops"
-                                           " should have only one table")
-                    distributed_lookup_table_ops.append(op)
-                else:
-                    if self.table_name is not None:
-                        assert op.input("W")[0] != self.table_name
-
-        return len(distributed_lookup_table_ops) > 0
-
-    def _update_dist_lookup_table_vars(self, param_list, grad_list,
-                                       params_grads):
-        # TODO(wuyi): put find a way to put dist lookup table stuff all together.
-        # update self.table_param_grad and self.trainer_side_table_grad_list
-        program = self.origin_program
-        if self.has_distributed_lookup_table:
-            param_list = [
-                param for param in param_list if param.name != self.table_name
-            ]
-            grad_list = [
-                grad for grad in grad_list
-                if grad.name != grad_var_name(self.table_name)
-            ]
-            self.table_param_grad = [
-                param_grad for param_grad in params_grads
-                if param_grad[0].name == self.table_name
-            ][0]
-            table_grad_var = self.table_param_grad[1]
-            if self.sync_mode:
-                self.trainer_side_table_grad_list = [
-                    program.global_block().create_var(
-                        name="%s.trainer_%d.pserver_%d" %
-                        (table_grad_var.name, self.trainer_id, index),
-                        type=table_grad_var.type,
-                        shape=table_grad_var.shape,
-                        dtype=table_grad_var.dtype)
-                    for index in range(len(self.pserver_endpoints))
-                ]
-            else:
-                self.trainer_side_table_grad_list = [
-                    program.global_block().create_var(
-                        name="%s.pserver_%d" % (table_grad_var.name, index),
-                        type=table_grad_var.type,
-                        shape=table_grad_var.shape,
-                        dtype=table_grad_var.dtype)
-                    for index in range(len(self.pserver_endpoints))
-                ]
-        return param_list, grad_list
-
-    def _init_splited_vars(self, slice_var_up):
-        # update these mappings for further transpile:
-        # 1. param_var_mapping: param var name -> [splited params vars]
-        # 2. grad_var_mapping: grad var name -> [splited grads vars]
-        # 3. grad_param_mapping: grad.blockx -> param.blockx
-        # 4. param_grad_ep_mapping: ep -> {"params": [], "grads": []}
-
-        param_list = []
-        grad_list = []
-        param_grad_set = set()
-        for p, g in self.params_grads:
-            # skip parameter marked not trainable
-            if type(p) == Parameter and p.trainable == False:
-                continue
-            if p.name not in param_grad_set:
-                param_list.append(p)
-                param_grad_set.add(p.name)
-            if g.name not in param_grad_set:
-                grad_list.append(g)
-                param_grad_set.add(g.name)
-
-        param_list, grad_list = self._update_dist_lookup_table_vars(
-            param_list, grad_list, self.params_grads)
-
-        if slice_var_up:
-            # when we slice var up into blocks, we will slice the var according to
-            # pserver services' count. A pserver may have two or more listening ports.
-            grad_blocks = slice_variable(grad_list, len(self.pserver_endpoints))
-            param_blocks = slice_variable(param_list,
-                                          len(self.pserver_endpoints))
-        else:
-            # when we do NOT slice var up into blocks, we will always slice params
-            # grads into one block.
-            grad_blocks = slice_variable(grad_list, 1)
-            param_blocks = slice_variable(param_list, 1)
-        assert (len(grad_blocks) == len(param_blocks))
-
-        # origin_varname -> [splited_var]
-        self.param_var_mapping = self._create_vars_from_blocklist(
-            self.origin_program, param_blocks)
-        self.grad_var_mapping = self._create_vars_from_blocklist(
-            self.origin_program,
-            grad_blocks,
-            add_trainer_suffix=self.trainer_num > 1)
-        self.grad_param_mapping = dict()
-        for g, p in zip(grad_blocks, param_blocks):
-            g_name, g_bid, _ = g.split(":")
-            p_name, p_bid, _ = p.split(":")
-            self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
-                self.param_var_mapping[p_name][int(p_bid)]
-
-        # create mapping of endpoint -> split var to create pserver side program
-        self.param_grad_ep_mapping = dict()
-        [
-            self.param_grad_ep_mapping.update({
-                ep: {
-                    "params": [],
-                    "grads": []
-                }
-            }) for ep in self.pserver_endpoints
-        ]
+    """
+    **DistributeTranspiler**
+
+    Convert the fluid program to distributed data-parallelism programs.
+
+    The main_program will be transformed to use a remote parameter server
+    to do parameter optimization. And the optimization graph will be put
+    into a parameter server program.
+
+    Examples:
+        .. code-block:: python
+
+            # Define your model before these codes.
+            port = os.getenv("PADDLE_PSERVER_PORT", "6174")
+            pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
+            eplist = []
+            for ip in pserver_ips.split(","):
+                eplist.append(':'.join([ip, port]))
+            pserver_endpoints = ",".join(eplist)
+            trainers = int(os.getenv("PADDLE_TRAINERS"))
+            current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
+            trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+            role = os.getenv("PADDLE_TRAINING_ROLE")
+
+            t = distribute_transpiler.DistributeTranspiler()
+            t.transpile(
+                trainer_id, pservers=pserver_endpoints, trainers=trainers)
+            if role == "PSERVER":
+                pserver_program = t.get_pserver_program(current_endpoint)
+                pserver_startup_program = t.get_startup_program(current_endpoint,
+                                                                pserver_program)
+            elif role == "TRAINER":
+                trainer_program = t.get_trainer_program()
+    """

    def transpile(self,
                  trainer_id,
@@ -250,20 +154,20 @@ def transpile(self,
                  split_method=RoundRobin,
                  sync_mode=True):
        """
-        :param trainer_id: one unique id for each trainer in a job.
-        :type trainer_id: int
-        :param program: program to transpile, default is default_main_program
-        :type program: Program
-        :param pservers: parameter server endpoints like "m1:6174,m2:6174"
-        :type pservers: string
-        :param trainers: total number of workers/trainers in the job
-        :type trainers: int
-        :param split_method: A function to determine how to split variables
-            to different servers equally.
-        :type split_method: function
-        :type sync_mode: boolean default True
-        :param sync_mode: if sync_mode is set True, it means that dist transpiler
-            will transpile the program into sync_mode pserver and trainer program.
+        Run the transpiler.
+
+        Args:
+            trainer_id (int): id for current trainer worker, if you have
+                n workers, the id may range from 0 ~ n-1
+            program (Program|None): program to transpile,
+                default is fluid.default_main_program().
+            pservers (str): comma separated ip:port string for the pserver
+                list.
+            trainers (int): number of trainers in the distributed job.
+            slice_var_up (bool): Do Tensor slice for pservers, default is True.
+            split_method (PSDispatcher): RoundRobin or HashName can be used;
+                try to choose the best method to balance loads for pservers.
+            sync_mode (bool): Do sync training or not, default is True.
        """
        assert (split_method.__bases__[0] == PSDispatcher)
        if program is None:
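For reference, here is a minimal sketch of a transpile() call exercising the arguments documented above. The endpoint string and trainer count are placeholders, and the import path is assumed from the class docstring's use of `distribute_transpiler`; adjust it to the actual package layout.

.. code-block:: python

    from paddle.fluid.transpiler import distribute_transpiler  # assumed path

    # Placeholder cluster description; real jobs usually read these from
    # environment variables, as in the class docstring above.
    pserver_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
    trainers = 2
    trainer_id = 0

    t = distribute_transpiler.DistributeTranspiler()
    t.transpile(
        trainer_id,
        program=None,               # falls back to fluid.default_main_program()
        pservers=pserver_endpoints,
        trainers=trainers,
        slice_var_up=True,          # slice large variables into blocks
        sync_mode=True)             # synchronous training; split_method stays RoundRobin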
@@ -390,6 +294,12 @@ def transpile(self,
        self._split_table_grad_and_add_send_vars(program, pserver_endpoints)

    def get_trainer_program(self):
+        """
+        Get transpiled trainer side program.
+
+        Returns:
+            Program: trainer side program.
+        """
        # remove optimize ops and add a send op to main_program
        delete_ops(self.origin_program.global_block(), self.optimize_ops)
        # FIXME(typhoonzero): serialize once will fix error occurs when clone.
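Below is a hedged sketch of the trainer-side flow once transpile() has run: fetch the transpiled program and hand it to an Executor. The transpiler object `t`, the DataFeeder, the batch `data`, and the `loss` variable are placeholders for whatever model was defined before transpiling.

.. code-block:: python

    import paddle.fluid as fluid

    # Trainer side, after t.transpile(...) has run (see the class docstring).
    trainer_program = t.get_trainer_program()

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())   # initialize local parameters first
    exe.run(trainer_program,
            feed=feeder.feed(data),            # placeholder: your DataFeeder and batch
            fetch_list=[loss])                 # placeholder: your loss variable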
@@ -398,12 +308,19 @@ def get_trainer_program(self):

    def get_pserver_program(self, endpoint):
        """
-        Get pserver side program using the endpoint.
-        TODO(panyx0718): Revisit this assumption. what if #blocks > #pservers.
-        NOTE: assume blocks of the same variable is not distributed
-        on the same pserver, only change param/grad varnames for
-        trainers to fetch.
+        Get parameter server side program.
+
+        Args:
+            endpoint (str): current parameter server endpoint.
+
+        Returns:
+            Program: the program for current parameter server to run.
        """
+        # TODO(panyx0718): Revisit this assumption. what if #blocks > #pservers.
+        # NOTE: assume blocks of the same variable is not distributed
+        # on the same pserver, only change param/grad varnames for
+        # trainers to fetch.
+
        # step1
        pserver_program = Program()
        # step2: Create vars to receive vars at parameter servers.
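A short sketch of how this method is typically used on the server side; the transpiler object `t` and the endpoint string are placeholders, and the endpoint would normally come from environment variables:

.. code-block:: python

    # Parameter server side, after t.transpile(...) has run in this process.
    current_endpoint = "192.168.0.1:6174"       # placeholder endpoint
    pserver_program = t.get_pserver_program(current_endpoint)
    # pserver_program now holds the serving logic plus the optimize blocks
    # for the parameter blocks that were dispatched to this endpoint.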
@@ -556,6 +473,14 @@ def get_startup_program(self, endpoint, pserver_program):
        Get startup program for current parameter server.
        Modify operator input variables if there are variables that
        were split to several blocks.
+
+        Args:
+            endpoint (str): current pserver endpoint.
+            pserver_program (Program): call get_pserver_program first and
+                pass the result here.
+
+        Returns:
+            Program: parameter server side startup program.
        """
        s_prog = Program()
        orig_s_prog = default_startup_program()
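The startup program pairs with the pserver program above: run it once to create and initialize the (possibly block-split) parameter variables on this pserver, then run the pserver program itself, which blocks while serving trainers. A minimal sketch, continuing the placeholders from the previous example:

.. code-block:: python

    import paddle.fluid as fluid

    pserver_startup = t.get_startup_program(current_endpoint, pserver_program)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(pserver_startup)    # initialize the parameter blocks owned by this pserver
    exe.run(pserver_program)    # start serving; does not return while training runs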
@@ -607,6 +532,129 @@ def _get_splited_name_and_shape(varname):

    # ====================== private transpiler functions =====================

+    def _has_distributed_lookup_table(self):
+        # process lookup_table_op
+        # 1. check all lookup_table_op is distributed
+        # 2. check all lookup_table_op share the same table.
+        distributed_lookup_table_ops = []
+        # support only one distributed_lookup_table now
+        self.table_name = None
+        for op in self.origin_program.global_block().ops:
+            if op.type == LOOKUP_TABLE_TYPE:
+                if op.attrs['is_distributed'] is True:
+                    if self.table_name is None:
+                        self.table_name = op.input("W")[0]
+                    if self.table_name != op.input("W")[0]:
+                        raise RuntimeError("all distributed lookup_table_ops"
+                                           " should have only one table")
+                    distributed_lookup_table_ops.append(op)
+                else:
+                    if self.table_name is not None:
+                        assert op.input("W")[0] != self.table_name
+
+        return len(distributed_lookup_table_ops) > 0
+
+    def _update_dist_lookup_table_vars(self, param_list, grad_list,
+                                       params_grads):
+        # TODO(wuyi): put find a way to put dist lookup table stuff all together.
+        # update self.table_param_grad and self.trainer_side_table_grad_list
+        program = self.origin_program
+        if self.has_distributed_lookup_table:
+            param_list = [
+                param for param in param_list if param.name != self.table_name
+            ]
+            grad_list = [
+                grad for grad in grad_list
+                if grad.name != grad_var_name(self.table_name)
+            ]
+            self.table_param_grad = [
+                param_grad for param_grad in params_grads
+                if param_grad[0].name == self.table_name
+            ][0]
+            table_grad_var = self.table_param_grad[1]
+            if self.sync_mode:
+                self.trainer_side_table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.trainer_%d.pserver_%d" %
+                        (table_grad_var.name, self.trainer_id, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]
+            else:
+                self.trainer_side_table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.pserver_%d" % (table_grad_var.name, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]
+        return param_list, grad_list
+
+    def _init_splited_vars(self, slice_var_up):
+        # update these mappings for further transpile:
+        # 1. param_var_mapping: param var name -> [splited params vars]
+        # 2. grad_var_mapping: grad var name -> [splited grads vars]
+        # 3. grad_param_mapping: grad.blockx -> param.blockx
+        # 4. param_grad_ep_mapping: ep -> {"params": [], "grads": []}
+
+        param_list = []
+        grad_list = []
+        param_grad_set = set()
+        for p, g in self.params_grads:
+            # skip parameter marked not trainable
+            if type(p) == Parameter and p.trainable == False:
+                continue
+            if p.name not in param_grad_set:
+                param_list.append(p)
+                param_grad_set.add(p.name)
+            if g.name not in param_grad_set:
+                grad_list.append(g)
+                param_grad_set.add(g.name)
+
+        param_list, grad_list = self._update_dist_lookup_table_vars(
+            param_list, grad_list, self.params_grads)
+
+        if slice_var_up:
+            # when we slice var up into blocks, we will slice the var according to
+            # pserver services' count. A pserver may have two or more listening ports.
+            grad_blocks = slice_variable(grad_list, len(self.pserver_endpoints))
+            param_blocks = slice_variable(param_list,
+                                          len(self.pserver_endpoints))
+        else:
+            # when we do NOT slice var up into blocks, we will always slice params
+            # grads into one block.
+            grad_blocks = slice_variable(grad_list, 1)
+            param_blocks = slice_variable(param_list, 1)
+        assert (len(grad_blocks) == len(param_blocks))
+
+        # origin_varname -> [splited_var]
+        self.param_var_mapping = self._create_vars_from_blocklist(
+            self.origin_program, param_blocks)
+        self.grad_var_mapping = self._create_vars_from_blocklist(
+            self.origin_program,
+            grad_blocks,
+            add_trainer_suffix=self.trainer_num > 1)
+        self.grad_param_mapping = dict()
+        for g, p in zip(grad_blocks, param_blocks):
+            g_name, g_bid, _ = g.split(":")
+            p_name, p_bid, _ = p.split(":")
+            self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
+                self.param_var_mapping[p_name][int(p_bid)]
+
+        # create mapping of endpoint -> split var to create pserver side program
+        self.param_grad_ep_mapping = dict()
+        [
+            self.param_grad_ep_mapping.update({
+                ep: {
+                    "params": [],
+                    "grads": []
+                }
+            }) for ep in self.pserver_endpoints
+        ]
+
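The `g.split(":")` calls above imply block descriptors of the form "varname:block_id:block_size". Here is a standalone sketch of the grad-to-param block pairing rule, using made-up variable names; the real method maps the split Variables themselves rather than strings, so this only illustrates the zip-and-split logic.

.. code-block:: python

    # Hypothetical block descriptors in the "varname:block_id:block_size" format.
    grad_blocks = ["fc_w@GRAD:0:4096", "fc_w@GRAD:1:4096", "fc_b@GRAD:0:64"]
    param_blocks = ["fc_w:0:4096", "fc_w:1:4096", "fc_b:0:64"]

    # Pair each grad block with the param block at the same position.
    grad_param_mapping = {}
    for g, p in zip(grad_blocks, param_blocks):
        g_name, g_bid, _ = g.split(":")
        p_name, p_bid, _ = p.split(":")
        grad_param_mapping[(g_name, int(g_bid))] = (p_name, int(p_bid))

    # {('fc_w@GRAD', 0): ('fc_w', 0), ('fc_w@GRAD', 1): ('fc_w', 1),
    #  ('fc_b@GRAD', 0): ('fc_b', 0)}
    print(grad_param_mapping)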
    # transpiler function for dis lookup_table
    def _replace_lookup_table_op_with_prefetch(self, program,
                                               pserver_endpoints):