
Commit f517d01

fix dist transpile with memopt (#12974)
* fix dist transpile with memopt
* update api.spec
* polish dist transpile api
* update apispec
* update apispec
1 parent a6dcadc commit f517d01

File tree

2 files changed: +69 -26 lines changed


paddle/fluid/API.spec

Lines changed: 6 additions & 4 deletions
@@ -55,9 +55,10 @@ paddle.fluid.Inferencer.__init__ ArgSpec(args=['self', 'infer_func', 'param_path
 paddle.fluid.Inferencer.infer ArgSpec(args=['self', 'inputs', 'return_numpy'], varargs=None, keywords=None, defaults=(True,))
 paddle.fluid.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True))
+paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None))
 paddle.fluid.InferenceTranspiler.__init__
 paddle.fluid.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0))
@@ -329,9 +330,10 @@ paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array
 paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True))
+paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None))
 paddle.fluid.transpiler.InferenceTranspiler.__init__
 paddle.fluid.transpiler.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.transpiler.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0))
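
The API.spec changes above boil down to one new argument on transpile() and one new method, get_pserver_programs(). Below is a minimal trainer-side sketch of how the updated signature might be used; the endpoints, trainer_id and trainer count are placeholders and the network-building code is assumed to exist already.

```python
import paddle.fluid as fluid

# ... network, loss and optimizer are assumed to be built on
# fluid.default_main_program() / fluid.default_startup_program() ...

# memory_optimize may remove intermediate variables from the program;
# handling that case in the transpiler is what this commit fixes.
fluid.memory_optimize(fluid.default_main_program())

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,                                  # placeholder: id of this trainer
    pservers="127.0.0.1:6174,127.0.0.1:6175",      # placeholder pserver endpoints
    trainers=2,
    sync_mode=True,
    startup_program=fluid.default_startup_program())  # new argument in this commit

trainer_prog = t.get_trainer_program()
```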

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 63 additions & 22 deletions
@@ -31,7 +31,7 @@
 """
 
 import math
-import random
+import sys
 import numpy as np
 import collections
 
@@ -181,7 +181,8 @@ def transpile(self,
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
-                  sync_mode=True):
+                  sync_mode=True,
+                  startup_program=None):
         """
         Run the transpiler.
 
@@ -194,13 +195,17 @@ def transpile(self,
                 list.
             trainers (int): number of trainers in the distributed job.
             sync_mode (bool): Do sync training or not, default is True.
+            startup_program (Program|None): startup_program to transpile,
+                default is fluid.default_startup_program().
         """
         if program is None:
             program = default_main_program()
+        if startup_program is None:
+            startup_program = default_startup_program()
         self.origin_program = program
-        self.origin_startup_program = default_startup_program().clone()
+        self.startup_program = startup_program
+        self.origin_startup_program = self.startup_program.clone()
 
-        self.startup_program = default_startup_program()
         self.trainer_num = trainers
         self.sync_mode = sync_mode
         self.trainer_id = trainer_id
@@ -369,21 +374,18 @@ def get_trainer_program(self):
 
         return self.origin_program
 
-    def _get_trainer_startup_program(self,
-                                     recv_vars,
-                                     eplist,
-                                     startup_program=None):
+    def _get_trainer_startup_program(self, recv_vars, eplist):
         """
         Get transpiled trainer side startup program.
 
         Args:
-            startup_program(Program): Startup program.
+            recv_vars (list): Variable list to recv for current trainer_id
+            eplist (list): A list of strings indicating
 
         Returns:
             Program: trainer side startup program.
         """
-        if startup_program is None:
-            startup_program = self.startup_program
+        startup_program = self.startup_program
 
         # FIXME(gongwb): delete not need ops.
         # note that: some parameter is not trainable and those ops can't be deleted.
@@ -431,7 +433,18 @@ def _get_trainer_startup_program(self,
             #add concat ops to merge splited parameters received from parameter servers.
             if len(splited_var) <= 1:
                 continue
-            orig_param = startup_program.global_block().vars[varname]
+            # NOTE: if enable memory optimization, origin vars maybe removed.
+            if startup_program.global_block().vars.has_key(varname):
+                orig_param = startup_program.global_block().vars[varname]
+            else:
+                origin_param_var = self.origin_program.global_block().vars[
+                    varname]
+                orig_param = startup_program.global_block().create_var(
+                    name=varname,
+                    persistable=origin_param_var.persistable,
+                    type=origin_param_var.type,
+                    dtype=origin_param_var.dtype,
+                    shape=origin_param_var.shape)
             startup_program.global_block().append_op(
                 type="concat",
                 inputs={"X": splited_var},
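
The guard added above is the heart of the memopt fix: if memory optimization already removed a parameter from the transpiled startup program, its metadata is recovered from the untouched origin_program. The same logic, sketched in isolation; `startup_program`, `varname` and `self.origin_program` are assumed to be bound as in the surrounding method, and `in` stands in for the Python 2 `has_key` call used in the diff.

```python
block = startup_program.global_block()
if varname in block.vars:
    # the parameter survived memory optimization, reuse it as-is
    orig_param = block.vars[varname]
else:
    # memory optimization stripped it from the startup program, so rebuild
    # the variable from the metadata kept in the untouched origin_program
    origin_param_var = self.origin_program.global_block().vars[varname]
    orig_param = block.create_var(
        name=varname,
        persistable=origin_param_var.persistable,
        type=origin_param_var.type,
        dtype=origin_param_var.dtype,
        shape=origin_param_var.shape)
```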
@@ -454,7 +467,9 @@ def get_pserver_program(self, endpoint):
         # NOTE: assume blocks of the same variable is not distributed
         # on the same pserver, only change param/grad varnames for
         # trainers to fetch.
-
+        sys.stderr.write("get_pserver_program() is deprecated, call\
+            get_pserver_programs() to get pserver main and startup\
+            in a single call.")
         # step1
         pserver_program = Program()
         pserver_program.random_seed = self.origin_program.random_seed
638653
attrs=attrs)
639654

640655
pserver_program._sync_with_cpp()
656+
# save pserver program to generate pserver side startup relatively.
657+
self.pserver_program = pserver_program
641658
return pserver_program
642659

660+
def get_pserver_programs(self, endpoint):
661+
"""
662+
Get pserver side main program and startup program for distributed training.
663+
664+
Args:
665+
endpoint (str): current pserver endpoint.
666+
667+
Returns:
668+
tuple: (main_program, startup_program), of type "Program"
669+
"""
670+
pserver_prog = self.get_pserver_program(endpoint)
671+
pserver_startup = self.get_startup_program(endpoint)
672+
return pserver_prog, pserver_startup
673+
643674
def get_startup_program(self,
644675
endpoint,
645-
pserver_program,
676+
pserver_program=None,
646677
startup_program=None):
647678
"""
679+
**Deprecated**
680+
648681
Get startup program for current parameter server.
649682
Modify operator input variables if there are variables that
650683
were split to several blocks.
651684
652685
Args:
653686
endpoint (str): current pserver endpoint.
654-
pserver_program (Program): call get_pserver_program first and
655-
pass the result here.
656-
startup_program (Program): if pass None, will use
657-
default_startup_program
687+
pserver_program (Program): deprecated, call get_pserver_program first.
688+
startup_program (Program): deprecated, should pass startup_program
689+
when initalizing
658690
659691
Returns:
660692
Program: parameter server side startup program.
661693
"""
694+
sys.stderr.write("get_startup_program() is deprecated, call\
695+
get_pserver_programs() to get pserver main and startup\
696+
in a single call.")
697+
if pserver_program != None:
698+
sys.stderr.write("passing pserver_program to get_startup_program()\
699+
is deprecated, you can use new API get_pserver_programs() to\
700+
get both pserver main program and startup program.")
701+
if startup_program != None:
702+
sys.stderr.write("passing startup_program to get_startup_program()\
703+
is deprecated, use fluid.program_guard() or pass this argument\
704+
to transpile() call.")
705+
662706
s_prog = Program()
663-
if not startup_program:
664-
orig_s_prog = default_startup_program()
665-
else:
666-
orig_s_prog = startup_program
707+
orig_s_prog = self.startup_program
667708
s_prog.random_seed = orig_s_prog.random_seed
668709
params = self.param_grad_ep_mapping[endpoint]["params"]
669710

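
On the parameter-server side, the new get_pserver_programs() returns both programs in one call, while the older two-call flow keeps working but now warns on stderr. A rough usage sketch under the same assumptions as the trainer-side example above; the endpoint strings and transpile() arguments are placeholders, and the same network is assumed to have been built as on the trainer.

```python
import paddle.fluid as fluid

# ... build the same network and optimizer as the trainers do ...

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,                                # placeholder
    pservers="127.0.0.1:6174,127.0.0.1:6175",    # placeholder pserver endpoints
    trainers=2)

current_endpoint = "127.0.0.1:6174"  # placeholder: endpoint served by this process

# New single-call API added by this commit; calling get_pserver_program()
# and get_startup_program() separately still works but prints a
# deprecation notice to stderr.
pserver_prog, pserver_startup = t.get_pserver_programs(current_endpoint)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(pserver_startup)   # initialize pserver-side variables
exe.run(pserver_prog)      # start serving parameter updates
```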