"""

import math
-import random
+import sys
import numpy as np
import collections

@@ -181,7 +181,8 @@ def transpile(self,
                  program=None,
                  pservers="127.0.0.1:6174",
                  trainers=1,
-                 sync_mode=True):
+                 sync_mode=True,
+                 startup_program=None):
        """
        Run the transpiler.

@@ -194,13 +195,17 @@ def transpile(self,
                list.
            trainers (int): number of trainers in the distributed job.
            sync_mode (bool): Do sync training or not, default is True.
+            startup_program (Program|None): startup_program to transpile,
+                default is fluid.default_startup_program().
        """
        if program is None:
            program = default_main_program()
+        if startup_program is None:
+            startup_program = default_startup_program()
        self.origin_program = program
-        self.origin_startup_program = default_startup_program().clone()
+        self.startup_program = startup_program
+        self.origin_startup_program = self.startup_program.clone()

-        self.startup_program = default_startup_program()
        self.trainer_num = trainers
        self.sync_mode = sync_mode
        self.trainer_id = trainer_id
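For reference, a minimal trainer-side sketch of how the new argument would be used (illustrative only: the toy network, endpoint list, and trainer count are placeholders, not part of this change, and the transpiler is assumed to be exposed as fluid.DistributeTranspiler):

import paddle.fluid as fluid

# Build the trainer network under explicit programs so both can be handed
# to the transpiler (placeholder network; any fluid model would do).
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    y_predict = fluid.layers.fc(input=x, size=1)
    avg_loss = fluid.layers.mean(
        fluid.layers.square_error_cost(input=y_predict, label=y))
    fluid.optimizer.SGD(learning_rate=0.01).minimize(avg_loss)

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,
    program=main_prog,
    pservers="127.0.0.1:6174,127.0.0.1:6175",
    trainers=2,
    sync_mode=True,
    startup_program=startup_prog)  # new optional argument added by this change
trainer_prog = t.get_trainer_program()

Previously the transpiler always cloned fluid.default_startup_program(); passing startup_program explicitly lets callers build their programs under fluid.program_guard() instead.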
@@ -369,21 +374,18 @@ def get_trainer_program(self):

        return self.origin_program

-    def _get_trainer_startup_program(self,
-                                     recv_vars,
-                                     eplist,
-                                     startup_program=None):
+    def _get_trainer_startup_program(self, recv_vars, eplist):
        """
        Get transpiled trainer side startup program.

        Args:
-            startup_program(Program): Startup program.
+            recv_vars (list): Variable list to receive for the current trainer_id.
+            eplist (list): A list of pserver endpoint strings, one for each
+                variable in recv_vars.

        Returns:
            Program: trainer side startup program.
        """
-        if startup_program is None:
-            startup_program = self.startup_program
+        startup_program = self.startup_program

        # FIXME(gongwb): delete not need ops.
        # note that: some parameter is not trainable and those ops can't be deleted.
@@ -431,7 +433,18 @@ def _get_trainer_startup_program(self,
            #add concat ops to merge splited parameters received from parameter servers.
            if len(splited_var) <= 1:
                continue
-            orig_param = startup_program.global_block().vars[varname]
+            # NOTE: if memory optimization is enabled, origin vars may have been removed.
+            if startup_program.global_block().vars.has_key(varname):
+                orig_param = startup_program.global_block().vars[varname]
+            else:
+                origin_param_var = self.origin_program.global_block().vars[
+                    varname]
+                orig_param = startup_program.global_block().create_var(
+                    name=varname,
+                    persistable=origin_param_var.persistable,
+                    type=origin_param_var.type,
+                    dtype=origin_param_var.dtype,
+                    shape=origin_param_var.shape)
            startup_program.global_block().append_op(
                type="concat",
                inputs={"X": splited_var},
@@ -454,7 +467,9 @@ def get_pserver_program(self, endpoint):
        # NOTE: assume blocks of the same variable is not distributed
        # on the same pserver, only change param/grad varnames for
        # trainers to fetch.
-
+        sys.stderr.write("get_pserver_program() is deprecated, call\
+            get_pserver_programs() to get pserver main and startup\
+            in a single call.")
        # step1
        pserver_program = Program()
        pserver_program.random_seed = self.origin_program.random_seed
@@ -638,32 +653,58 @@ def __clone_lr_op_sub_block__(op, program, lr_block):
                attrs=attrs)

        pserver_program._sync_with_cpp()
+        # save the pserver program so the pserver side startup program can be generated from it later.
+        self.pserver_program = pserver_program
        return pserver_program

+    def get_pserver_programs(self, endpoint):
+        """
+        Get pserver side main program and startup program for distributed training.
+
+        Args:
+            endpoint (str): current pserver endpoint.
+
+        Returns:
+            tuple: (main_program, startup_program), of type "Program"
+        """
+        pserver_prog = self.get_pserver_program(endpoint)
+        pserver_startup = self.get_startup_program(endpoint)
+        return pserver_prog, pserver_startup
+
    def get_startup_program(self,
                            endpoint,
-                            pserver_program,
+                            pserver_program=None,
                            startup_program=None):
        """
+        **Deprecated**
+
        Get startup program for current parameter server.
        Modify operator input variables if there are variables that
        were split to several blocks.

        Args:
            endpoint (str): current pserver endpoint.
-            pserver_program (Program): call get_pserver_program first and
-                pass the result here.
-            startup_program (Program): if pass None, will use
-                default_startup_program
+            pserver_program (Program): deprecated, call get_pserver_program first.
+            startup_program (Program): deprecated, pass startup_program
+                to the transpile() call instead.

        Returns:
            Program: parameter server side startup program.
        """
+        sys.stderr.write("get_startup_program() is deprecated, call\
+            get_pserver_programs() to get pserver main and startup\
+            in a single call.")
+        if pserver_program != None:
+            sys.stderr.write("passing pserver_program to get_startup_program()\
+                is deprecated, you can use new API get_pserver_programs() to\
+                get both pserver main program and startup program.")
+        if startup_program != None:
+            sys.stderr.write("passing startup_program to get_startup_program()\
+                is deprecated, use fluid.program_guard() or pass this argument\
+                to transpile() call.")
+
        s_prog = Program()
-        if not startup_program:
-            orig_s_prog = default_startup_program()
-        else:
-            orig_s_prog = startup_program
+        orig_s_prog = self.startup_program
        s_prog.random_seed = orig_s_prog.random_seed
        params = self.param_grad_ep_mapping[endpoint]["params"]

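On the parameter-server side, the deprecated two-call sequence is replaced by the new combined call. A minimal sketch (continuing the trainer-side sketch above: the endpoint string is a placeholder, and `t` is assumed to be a transpiler on which transpile() has already been run in this pserver process):

import paddle.fluid as fluid

current_endpoint = "127.0.0.1:6174"  # this pserver's own endpoint (placeholder)

# New API added by this change: one call returns both the pserver main
# program and its matching startup program.
pserver_prog, pserver_startup = t.get_pserver_programs(current_endpoint)

# Equivalent to the deprecated pair of calls:
#   pserver_prog = t.get_pserver_program(current_endpoint)
#   pserver_startup = t.get_startup_program(current_endpoint)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(pserver_startup)  # initialize the parameters hosted on this pserver
exe.run(pserver_prog)     # start serving trainer requests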