Skip to content

Commit bd87f67

Browse files
authored
Dist transpile can pass startup program by argument (#12606)
* dist transpile can pass startup program by argument * update API.spec
1 parent 9333a62 commit bd87f67

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

paddle/fluid/API.spec

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ paddle.fluid.Inferencer.__init__ ArgSpec(args=['self', 'infer_func', 'param_path
5555
paddle.fluid.Inferencer.infer ArgSpec(args=['self', 'inputs', 'return_numpy'], varargs=None, keywords=None, defaults=(True,))
5656
paddle.fluid.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
5757
paddle.fluid.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
58-
paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program'], varargs=None, keywords=None, defaults=None)
58+
paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,))
5959
paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
6060
paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True))
6161
paddle.fluid.InferenceTranspiler.__init__
@@ -328,7 +328,7 @@ paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array
328328
paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None)
329329
paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
330330
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
331-
paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program'], varargs=None, keywords=None, defaults=None)
331+
paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,))
332332
paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
333333
paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True))
334334
paddle.fluid.transpiler.InferenceTranspiler.__init__

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,10 @@ def __clone_lr_op_sub_block__(op, program, lr_block):
530530
pserver_program._sync_with_cpp()
531531
return pserver_program
532532

533-
def get_startup_program(self, endpoint, pserver_program):
533+
def get_startup_program(self,
534+
endpoint,
535+
pserver_program,
536+
startup_program=None):
534537
"""
535538
Get startup program for current parameter server.
536539
Modify operator input variables if there are variables that
@@ -540,12 +543,17 @@ def get_startup_program(self, endpoint, pserver_program):
540543
endpoint (str): current pserver endpoint.
541544
pserver_program (Program): call get_pserver_program first and
542545
pass the result here.
546+
startup_program (Program): if pass None, will use
547+
default_startup_program
543548
544549
Returns:
545550
Program: parameter server side startup program.
546551
"""
547552
s_prog = Program()
548-
orig_s_prog = default_startup_program()
553+
if not startup_program:
554+
orig_s_prog = default_startup_program()
555+
else:
556+
orig_s_prog = startup_program
549557
s_prog.random_seed = orig_s_prog.random_seed
550558
params = self.param_grad_ep_mapping[endpoint]["params"]
551559

0 commit comments

Comments
 (0)