Commit 9ae55dd

fix dist transpile with memopt (#12974)

* fix dist transpile with memopt
* update api.spec
* polish dist transpile api
* update apispec
* update apispec

1 parent 902f19b · commit 9ae55dd

File tree

2 files changed: +69 −25 lines

paddle/fluid/API.spec

Lines changed: 6 additions & 4 deletions
```diff
@@ -55,9 +55,10 @@ paddle.fluid.Inferencer.__init__ ArgSpec(args=['self', 'infer_func', 'param_path
 paddle.fluid.Inferencer.infer ArgSpec(args=['self', 'inputs', 'return_numpy'], varargs=None, keywords=None, defaults=(True,))
 paddle.fluid.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True))
+paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None))
 paddle.fluid.InferenceTranspiler.__init__
 paddle.fluid.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0))
@@ -335,9 +336,10 @@ paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array
 paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True))
+paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None))
 paddle.fluid.transpiler.InferenceTranspiler.__init__
 paddle.fluid.transpiler.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.transpiler.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0))
```
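Each API.spec line is the printed `ArgSpec` of a public API, which is why the new `get_pserver_programs` entry and the widened `defaults` tuples appear verbatim above. As a minimal sketch of where such a line comes from (assuming only that `paddle.fluid` is importable; the repo's actual spec generator lives in its CI tooling and is not shown here):

```python
import inspect

import paddle.fluid as fluid

# Rebuild one API.spec entry from the live signature. inspect.getargspec
# is the Python 2-era call whose repr matches this ArgSpec format.
spec = inspect.getargspec(fluid.DistributeTranspiler.transpile)
print("paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=%r, "
      "varargs=%r, keywords=%r, defaults=%r)"
      % (spec.args, spec.varargs, spec.keywords, spec.defaults))
```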

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 63 additions & 21 deletions
```diff
@@ -31,6 +31,7 @@
 """
 
 import math
+import sys
 import numpy as np
 import collections
 import six
```
```diff
@@ -181,7 +182,8 @@ def transpile(self,
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
-                  sync_mode=True):
+                  sync_mode=True,
+                  startup_program=None):
         """
         Run the transpiler.
 
@@ -194,13 +196,17 @@ def transpile(self,
                 list.
             trainers (int): number of trainers in the distributed job.
             sync_mode (bool): Do sync training or not, default is True.
+            startup_program (Program|None): startup_program to transpile,
+                default is fluid.default_startup_program().
         """
         if program is None:
             program = default_main_program()
+        if startup_program is None:
+            startup_program = default_startup_program()
         self.origin_program = program
-        self.origin_startup_program = default_startup_program().clone()
+        self.startup_program = startup_program
+        self.origin_startup_program = self.startup_program.clone()
 
-        self.startup_program = default_startup_program()
         self.trainer_num = trainers
         self.sync_mode = sync_mode
         self.trainer_id = trainer_id
```
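With `startup_program` now a parameter of `transpile()`, a job that builds (or memory-optimizes) its own startup program can pass it explicitly instead of relying on the global default. A minimal trainer-side sketch; the endpoint string and trainer count below are illustrative assumptions, not from this commit:

```python
import paddle.fluid as fluid

main_prog = fluid.default_main_program()
startup_prog = fluid.default_startup_program()

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,                  # this trainer's index in [0, trainers)
    program=main_prog,
    pservers="127.0.0.1:6174",     # comma-separated pserver endpoints
    trainers=2,
    sync_mode=True,
    startup_program=startup_prog)  # new argument; None falls back to
                                   # fluid.default_startup_program()

trainer_prog = t.get_trainer_program()
```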
```diff
@@ -376,21 +382,18 @@ def get_trainer_program(self):
 
         return self.origin_program
 
-    def _get_trainer_startup_program(self,
-                                     recv_vars,
-                                     eplist,
-                                     startup_program=None):
+    def _get_trainer_startup_program(self, recv_vars, eplist):
         """
         Get transpiled trainer side startup program.
 
         Args:
-            startup_program(Program): Startup program.
+            recv_vars (list): Variable list to recv for current trainer_id
+            eplist (list): A list of strings indicating
 
         Returns:
             Program: trainer side startup program.
         """
-        if startup_program is None:
-            startup_program = self.startup_program
+        startup_program = self.startup_program
 
         # FIXME(gongwb): delete ops that are not needed.
         # note that some parameters are not trainable, so their ops can't be deleted.
```
```diff
@@ -438,7 +441,18 @@ def _get_trainer_startup_program(self,
         # add concat ops to merge splited parameters received from parameter servers.
         if len(splited_var) <= 1:
             continue
-        orig_param = startup_program.global_block().vars[varname]
+        # NOTE: if memory optimization is enabled, origin vars may have been removed.
+        if startup_program.global_block().vars.has_key(varname):
+            orig_param = startup_program.global_block().vars[varname]
+        else:
+            origin_param_var = self.origin_program.global_block().vars[
+                varname]
+            orig_param = startup_program.global_block().create_var(
+                name=varname,
+                persistable=origin_param_var.persistable,
+                type=origin_param_var.type,
+                dtype=origin_param_var.dtype,
+                shape=origin_param_var.shape)
         startup_program.global_block().append_op(
             type="concat",
             inputs={"X": splited_var},
```
```diff
@@ -461,7 +475,9 @@ def get_pserver_program(self, endpoint):
         # NOTE: assume blocks of the same variable is not distributed
         # on the same pserver, only change param/grad varnames for
         # trainers to fetch.
-
+        sys.stderr.write("get_pserver_program() is deprecated, call "
+                         "get_pserver_programs() to get pserver main and "
+                         "startup in a single call.\n")
         # step1
         pserver_program = Program()
         pserver_program.random_seed = self.origin_program.random_seed
```
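Writing the notice to `sys.stderr` keeps the old entry point intact while nudging callers toward `get_pserver_programs()`. An alternative, not used by this commit, would be Python's standard `warnings` machinery, which downstream code can filter or escalate; a sketch:

```python
import warnings

# The same deprecation notice routed through the warnings module, so
# callers can control it with warnings.filterwarnings().
warnings.warn(
    "get_pserver_program() is deprecated, call get_pserver_programs() "
    "to get pserver main and startup in a single call.",
    DeprecationWarning,
    stacklevel=2)  # attribute the warning to the caller, not this helper
```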
```diff
@@ -651,32 +667,58 @@ def __clone_lr_op_sub_block__(op, program, lr_block):
                                                 endpoint)
 
         pserver_program._sync_with_cpp()
+        # save the pserver program so the pserver side startup program can be generated from it.
+        self.pserver_program = pserver_program
         return pserver_program
 
+    def get_pserver_programs(self, endpoint):
+        """
+        Get pserver side main program and startup program for distributed training.
+
+        Args:
+            endpoint (str): current pserver endpoint.
+
+        Returns:
+            tuple: (main_program, startup_program), of type "Program"
+        """
+        pserver_prog = self.get_pserver_program(endpoint)
+        pserver_startup = self.get_startup_program(endpoint)
+        return pserver_prog, pserver_startup
+
     def get_startup_program(self,
                             endpoint,
-                            pserver_program,
+                            pserver_program=None,
                             startup_program=None):
         """
+        **Deprecated**
+
         Get startup program for current parameter server.
         Modify operator input variables if there are variables that
         were split to several blocks.
 
         Args:
             endpoint (str): current pserver endpoint.
-            pserver_program (Program): call get_pserver_program first and
-                pass the result here.
-            startup_program (Program): if pass None, will use
-                default_startup_program
+            pserver_program (Program): deprecated, call get_pserver_program first.
+            startup_program (Program): deprecated, should pass startup_program
+                when initializing.
 
         Returns:
             Program: parameter server side startup program.
         """
+        sys.stderr.write("get_startup_program() is deprecated, call "
+                         "get_pserver_programs() to get pserver main and "
+                         "startup in a single call.\n")
+        if pserver_program is not None:
+            sys.stderr.write("passing pserver_program to get_startup_program() "
+                             "is deprecated, use the new API get_pserver_programs() "
+                             "to get both pserver main program and startup program.\n")
+        if startup_program is not None:
+            sys.stderr.write("passing startup_program to get_startup_program() "
+                             "is deprecated, use fluid.program_guard() or pass this "
+                             "argument to the transpile() call.\n")
+
         s_prog = Program()
-        if not startup_program:
-            orig_s_prog = default_startup_program()
-        else:
-            orig_s_prog = startup_program
+        orig_s_prog = self.startup_program
         s_prog.random_seed = orig_s_prog.random_seed
         params = self.param_grad_ep_mapping[endpoint]["params"]
```
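On the pserver side, the new `get_pserver_programs()` replaces the `get_pserver_program()` + `get_startup_program()` pair with a single call. A minimal pserver-side sketch; the endpoint string is assumed to match one entry of the `pservers` list given to `transpile()`:

```python
import paddle.fluid as fluid

t = fluid.DistributeTranspiler()
t.transpile(trainer_id=0, pservers="127.0.0.1:6174", trainers=2)

# New in this commit: fetch main and startup programs in one call.
pserver_prog, pserver_startup = t.get_pserver_programs("127.0.0.1:6174")

exe = fluid.Executor(fluid.CPUPlace())
exe.run(pserver_startup)  # initialize pserver-side variables first
exe.run(pserver_prog)     # then start the listen-and-serve loop
```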