 """

 import math
+import sys
 import numpy as np
 import collections
 import six
@@ -181,7 +182,8 @@ def transpile(self,
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
-                  sync_mode=True):
+                  sync_mode=True,
+                  startup_program=None):
         """
         Run the transpiler.

@@ -194,13 +196,17 @@ def transpile(self,
                 list.
             trainers (int): number of trainers in the distributed job.
             sync_mode (bool): Do sync training or not, default is True.
+            startup_program (Program|None): startup_program to transpile,
+                default is fluid.default_startup_program().
         """
         if program is None:
             program = default_main_program()
+        if startup_program is None:
+            startup_program = default_startup_program()
         self.origin_program = program
-        self.origin_startup_program = default_startup_program().clone()
+        self.startup_program = startup_program
+        self.origin_startup_program = self.startup_program.clone()

-        self.startup_program = default_startup_program()
         self.trainer_num = trainers
         self.sync_mode = sync_mode
         self.trainer_id = trainer_id
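With `transpile()` now accepting `startup_program`, callers no longer have to rely on the global `default_startup_program()`. A minimal trainer-side sketch, assuming the `fluid.DistributeTranspiler` API of this release; the endpoints, trainer count, and the model-building step are illustrative, not part of this commit:

    import paddle.fluid as fluid

    # Build the network under the default main/startup programs as usual,
    # then transpile both programs explicitly.
    t = fluid.DistributeTranspiler()
    t.transpile(
        trainer_id=0,
        program=fluid.default_main_program(),
        startup_program=fluid.default_startup_program(),
        pservers="127.0.0.1:6174,127.0.0.1:6175",
        trainers=2,
        sync_mode=True)
    trainer_prog = t.get_trainer_program()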
@@ -376,21 +382,18 @@ def get_trainer_program(self):

         return self.origin_program

-    def _get_trainer_startup_program(self,
-                                     recv_vars,
-                                     eplist,
-                                     startup_program=None):
+    def _get_trainer_startup_program(self, recv_vars, eplist):
         """
         Get transpiled trainer side startup program.

         Args:
-            startup_program(Program): Startup program.
+            recv_vars (list): Variable list to recv for current trainer_id
+            eplist (list): A list of strings indicating the pserver endpoints to receive each variable from.

         Returns:
             Program: trainer side startup program.
         """
-        if startup_program is None:
-            startup_program = self.startup_program
+        startup_program = self.startup_program

         # FIXME(gongwb): delete not need ops.
         # note that: some parameter is not trainable and those ops can't be deleted.
@@ -438,7 +441,18 @@ def _get_trainer_startup_program(self,
             #add concat ops to merge splited parameters received from parameter servers.
             if len(splited_var) <= 1:
                 continue
-            orig_param = startup_program.global_block().vars[varname]
+            # NOTE: if memory optimization is enabled, origin vars may have been removed.
+            if startup_program.global_block().vars.has_key(varname):
+                orig_param = startup_program.global_block().vars[varname]
+            else:
+                origin_param_var = self.origin_program.global_block().vars[
+                    varname]
+                orig_param = startup_program.global_block().create_var(
+                    name=varname,
+                    persistable=origin_param_var.persistable,
+                    type=origin_param_var.type,
+                    dtype=origin_param_var.dtype,
+                    shape=origin_param_var.shape)
             startup_program.global_block().append_op(
                 type="concat",
                 inputs={"X": splited_var},
@@ -461,7 +475,9 @@ def get_pserver_program(self, endpoint):
         # NOTE: assume blocks of the same variable is not distributed
         # on the same pserver, only change param/grad varnames for
         # trainers to fetch.
-
+        sys.stderr.write("get_pserver_program() is deprecated, call \
+get_pserver_programs() to get pserver main and startup \
+in a single call.")
         # step1
         pserver_program = Program()
         pserver_program.random_seed = self.origin_program.random_seed
@@ -651,32 +667,58 @@ def __clone_lr_op_sub_block__(op, program, lr_block):
                                               endpoint)

         pserver_program._sync_with_cpp()
+        # save pserver program so the pserver side startup program can be generated from it later.
+        self.pserver_program = pserver_program
         return pserver_program

+    def get_pserver_programs(self, endpoint):
+        """
+        Get pserver side main program and startup program for distributed training.
+
+        Args:
+            endpoint (str): current pserver endpoint.
+
+        Returns:
+            tuple: (main_program, startup_program) of type "Program"
+        """
+        pserver_prog = self.get_pserver_program(endpoint)
+        pserver_startup = self.get_startup_program(endpoint)
+        return pserver_prog, pserver_startup
+
     def get_startup_program(self,
                             endpoint,
-                            pserver_program,
+                            pserver_program=None,
                             startup_program=None):
         """
+        **Deprecated**
+
         Get startup program for current parameter server.
         Modify operator input variables if there are variables that
         were split to several blocks.

         Args:
             endpoint (str): current pserver endpoint.
-            pserver_program (Program): call get_pserver_program first and
-                pass the result here.
-            startup_program (Program): if pass None, will use
-                default_startup_program
+            pserver_program (Program): deprecated, call get_pserver_program first.
+            startup_program (Program): deprecated, should pass startup_program
+                when calling transpile().

         Returns:
             Program: parameter server side startup program.
         """
+        sys.stderr.write("get_startup_program() is deprecated, call \
+get_pserver_programs() to get pserver main and startup \
+in a single call.")
+        if pserver_program is not None:
+            sys.stderr.write("passing pserver_program to get_startup_program() \
+is deprecated, you can use new API get_pserver_programs() to \
+get both pserver main program and startup program.")
+        if startup_program is not None:
+            sys.stderr.write("passing startup_program to get_startup_program() \
+is deprecated, use fluid.program_guard() or pass this argument \
+to transpile() call.")
+
         s_prog = Program()
-        if not startup_program:
-            orig_s_prog = default_startup_program()
-        else:
-            orig_s_prog = startup_program
+        orig_s_prog = self.startup_program
         s_prog.random_seed = orig_s_prog.random_seed
         params = self.param_grad_ep_mapping[endpoint]["params"]

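On the parameter-server side, the new `get_pserver_programs()` replaces the two-step `get_pserver_program()` + `get_startup_program()` sequence. A hedged sketch under the same assumptions as above; endpoint strings and trainer counts are placeholders:

    import paddle.fluid as fluid

    t = fluid.DistributeTranspiler()
    t.transpile(
        trainer_id=0,
        pservers="127.0.0.1:6174,127.0.0.1:6175",
        trainers=2)

    # One call returns both the pserver main program and its startup program,
    # instead of get_pserver_program() followed by get_startup_program().
    pserver_prog, pserver_startup = t.get_pserver_programs("127.0.0.1:6174")

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(pserver_startup)
    exe.run(pserver_prog)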