Commit 10ececb

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_sequence_pad_2
2 parents 802b334 + 8ea4218 commit 10ececb

File tree

2 files changed: +230 additions, -70 deletions


paddle/fluid/API.spec

Lines changed: 3 additions & 3 deletions

@@ -393,9 +393,9 @@ paddle.fluid.contrib.MagnitudePruner.__init__ (ArgSpec(args=['self', 'threshold'
 paddle.fluid.contrib.MagnitudePruner.prune (ArgSpec(args=['self', 'param', 'threshold'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.contrib.RatioPruner.__init__ (ArgSpec(args=['self', 'ratios'], varargs=None, keywords=None, defaults=(None,)), ('document', 'e7a81a325b296a9ca502ee5adb4fc85d'))
 paddle.fluid.contrib.RatioPruner.prune (ArgSpec(args=['self', 'param', 'ratio'], varargs=None, keywords=None, defaults=(None,)), ('document', '358cbf2978c91028fb96a195a9884645'))
-paddle.fluid.contrib.load_persistables_for_increment (ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var', 'lookup_table_var_path'], varargs=None, keywords=None, defaults=None), ('document', '11fbf7e8dd2289805de291b453a33ee7'))
-paddle.fluid.contrib.load_persistables_for_inference (ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var_name'], varargs=None, keywords=None, defaults=None), ('document', '5b5577bb3d24070da819674255d16196'))
-paddle.fluid.contrib.convert_dist_to_sparse_program (ArgSpec(args=['program'], varargs=None, keywords=None, defaults=None), ('document', '4efbd93876832d4d35497cdbc7a1e6d8'))
+paddle.fluid.contrib.load_persistables_for_increment (ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var', 'lookup_table_var_path'], varargs=None, keywords=None, defaults=None), ('document', '2ab36d4f7a564f5f65e455807ad06c67'))
+paddle.fluid.contrib.load_persistables_for_inference (ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var_name'], varargs=None, keywords=None, defaults=None), ('document', '59066bac9db0ac6ce414d05780b7333f'))
+paddle.fluid.contrib.convert_dist_to_sparse_program (ArgSpec(args=['program'], varargs=None, keywords=None, defaults=None), ('document', '74c39c595dc70d6be2f16d8e462d282b'))
 paddle.fluid.contrib.HDFSClient.__init__ (ArgSpec(args=['self', 'hadoop_home', 'configs'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.contrib.HDFSClient.delete (ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=None), ('document', 'c3721aa2d4d9ef5a857dd47b2681c03e'))
 paddle.fluid.contrib.HDFSClient.download (ArgSpec(args=['self', 'hdfs_path', 'local_path', 'overwrite', 'unzip'], varargs=None, keywords=None, defaults=(False, False)), ('document', 'ca55bde92184d3fd0f9f5c963b25e634'))
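
Note: the three updated 'document' hashes above track the docstring rewrites in lookup_table_utils.py below; the 'document' field in API.spec appears to be an MD5 digest of each API's docstring, so rewording a docstring forces a hash update here. A minimal sketch of that digest, assuming the MD5-of-docstring convention:

    import hashlib

    def doc_hash(docstring):
        # Assumed convention: API.spec records the MD5 hex digest of the docstring.
        return hashlib.md5(docstring.encode('utf-8')).hexdigest()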

python/paddle/fluid/contrib/utils/lookup_table_utils.py

Lines changed: 227 additions & 67 deletions

@@ -18,6 +18,7 @@
 import time
 import logging
 
+import paddle
 from paddle.fluid import core
 from paddle.fluid import io
 from paddle.fluid import Program
@@ -84,8 +85,9 @@ def convert_dist_to_sparse_program(program):
     when we train model with distributed lookup table but want to do the local inference, we can use
     this function to convert the train program with distributed lookup table to sparse lookup table.
 
-    :param program(Program): the program must be the trainer program, which will be get by the distribute transpiler.
-    :return:
+    Args:
+        program(Program): the program must be the trainer program, which will be get by the distribute transpiler.
+    Returns:
     program: The `program` is a Program, it's the program replace distributed lookup table to sparse lookup table.
     """
     if not program._distributed_lookup_table:
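
Note: the rewritten docstring reflects the intended call pattern: feed the trainer program produced by the distribute transpiler into this helper and get back a program whose distributed lookup table is replaced by a local sparse one. A hedged usage sketch; the transpiler setup and `trainer_program` are assumed, not part of this commit:

    from paddle.fluid import contrib

    # `trainer_program` is assumed to come from the distribute transpiler,
    # e.g. t.get_trainer_program() after t.transpile(...).
    sparse_program = contrib.convert_dist_to_sparse_program(trainer_program)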
@@ -128,68 +130,92 @@ def convert_dist_to_sparse_program(program):
     return program
 
 
-def _load_persistable_vars(executor, dirname, program, lookup_table_vars):
-    def _is_checkpoint_var(exclude_fluid_vars=None):
-        """
-        the checkpoint will not save or load all the variables.
-        var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
-
-        : param var(Variable)
-        """
-
-        if exclude_fluid_vars is None:
-            exclude_fluid_vars = []
-
-        def is_valid(var):
-            if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
-                    var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
-                    var.desc.type() == core.VarDesc.VarType.RAW:
-                return False
-            # @GRAD are named for gradient variables, checkpoint will not save it.
-            if "@GRAD" in var.name:
-                return False
-            # .trainer_ are named for distribute train variables, checkpoint will not save it.
-            if ".trainer_" in var.name:
-                return False
-
-            # .block is named for distribute train variables, checkpoint will not save it.
-            if ".block" in var.name:
-                return False
-
-            if "tmp_" in var.name:
-                return False
-
-            if var.name in exclude_fluid_vars:
-                return False
-
-            return var.persistable
-
-        return is_valid
-
-    io.load_vars(
-        executor,
-        dirname=dirname,
-        main_program=program,
-        predicate=_is_checkpoint_var(lookup_table_vars),
-        filename=None)
-
-
 def load_persistables_for_increment(dirname, executor, program,
                                     lookup_table_var, lookup_table_var_path):
     """
     WARNING: this function will only be used for distributed training with distributed lookup table.
     for increment trainning, the pserver will not only load dense variables,
-    but also load the suitable lookup table var. Because of slice lookup table
-    var with HASH, we must load the correct slice var.
+    but also load the suitable lookup table var. Because of sliced lookup table
+    var with HASH, we must load the correct sliced var.
+
+    Args:
+        dirname(str): The directory path
+        executor(Executor): The executor to run for loading inference model.
+        program(Program): The parameter server program, which will run on Pserver.
+        lookup_table_var: the distributed lookup tables var name.
+        lookup_table_var_path: the the distributed lookup tables var location.
+
+    Returns:
+        None
+    """
 
+    def _load_persistable_vars(executor, dirname, need_load_vars):
+        load_prog = Program()
+        load_block = load_prog.global_block()
+        need_delete_vars = []
+
+        for param in need_load_vars:
+            origin_var = param.origin
+            slice_var = param.slice
+            is_slice = param.is_slice
+            offset = param.offset
+
+            if is_slice:
+                origin = load_block.create_var(
+                    name="{}.load".format(origin_var.name),
+                    type=origin_var.type,
+                    shape=origin_var.shape,
+                    dtype=origin_var.dtype,
+                    persistable=True)
+
+                load_block.append_op(
+                    type='load',
+                    inputs={},
+                    outputs={'Out': [origin]},
+                    attrs={
+                        'file_path': os.path.join(dirname, origin_var.name)
+                    })
+
+                slice = load_block.create_var(
+                    name=slice_var.name,
+                    type=slice_var.type,
+                    shape=slice_var.shape,
+                    dtype=slice_var.dtype,
+                    persistable=True)
+
+                dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:])
+                start = int(offset / dim1_flatten)
+                end = int(offset / dim1_flatten + slice.shape[0])
+
+                load_block.append_op(
+                    type="slice",
+                    inputs={'Input': origin},
+                    outputs={'Out': slice},
+                    attrs={'axes': [0],
+                           'starts': [start],
+                           'ends': [end]})
+
+                need_delete_vars.append(origin)
+            else:
+                origin = load_block.create_var(
+                    name="{}".format(origin_var.name),
+                    type=origin_var.type,
+                    shape=origin_var.shape,
+                    dtype=origin_var.dtype,
+                    persistable=True)
+                load_block.append_op(
+                    type='load',
+                    inputs={},
+                    outputs={'Out': [origin]},
+                    attrs={
+                        'file_path': os.path.join(dirname, origin_var.name)
+                    })
 
-    :param dirname(str): The directory path
-    :param executor(Executor): The executor to run for loading inference model.
-    :param program(Program): The parameter server program, which will run on Pserver.
-    :param lookup_table_var: the distributed lookup tables var name.
-    :param lookup_table_var_path: the the distributed lookup tables var location.
-    :return: None
-    """
+        load_block.append_op(
+            type='delete_var',
+            inputs={'X': need_delete_vars}, )
+
+        executor.run(load_prog)
 
     def __load_lookup_table_vars(executor, main_program, lookup_table_var,
                                  lookup_table_var_path):
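
Note: the slice branch above recovers a pserver shard from the full table by turning the element-wise `offset` into row indices: `dim1_flatten` is the number of elements per row, so `start = offset / dim1_flatten` and `end = start + slice.shape[0]`. A small sketch of the same arithmetic in plain Python, with hypothetical shapes not taken from this commit:

    from functools import reduce

    # Hypothetical: the full table is 10000 x 8, and this shard holds
    # 2500 rows starting at element offset 20000.
    slice_shape = (2500, 8)
    offset = 20000

    dim1_flatten = reduce(lambda x, y: x * y, slice_shape[1:])  # 8 elements per row
    start = offset // dim1_flatten                              # row 2500
    end = start + slice_shape[0]                                # row 5000
    print(start, end)  # -> 2500 5000, i.e. rows [2500, 5000) of the full table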
@@ -217,7 +243,9 @@ def __load_lookup_table_vars(executor, main_program, lookup_table_var,
         "Distributed Lookup Table Vars from {}, time = {}".format(
             dirname, time.ctime()))
 
-    _load_persistable_vars(executor, dirname, program, [lookup_table_var])
+    need_load_vars = program._parameters_on_pservers.get_distributed_vars_by_ep(
+        program._ps_endpoint)
+    _load_persistable_vars(executor, dirname, need_load_vars)
     __load_lookup_table_vars(executor, program, lookup_table_var,
                              lookup_table_var_path)
 
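
Note: instead of loading a hard-coded variable list, the pserver now asks `_parameters_on_pservers` for exactly the (possibly sliced) variables assigned to its own endpoint. A hedged sketch of the call site; the executor setup, paths, and names are invented for illustration:

    import paddle.fluid as fluid
    from paddle.fluid import contrib

    exe = fluid.Executor(fluid.CPUPlace())
    # `pserver_program` is assumed to come from the distribute transpiler.
    contrib.load_persistables_for_increment(
        dirname="/path/to/checkpoint",
        executor=exe,
        program=pserver_program,
        lookup_table_var="dist_lookup_table",
        lookup_table_var_path="/path/to/checkpoint/dist_lookup_table")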
@@ -233,15 +261,62 @@ def load_persistables_for_inference(dirname, executor, program,
     Inference with distributed lookup table is a little funky, this function will load distributed
     lookup table vars into sparse var, can be used in local inference mode.
 
-    :param dirname(str): The directory path
-    :param executor(Executor): The executor to run for loading inference model.
-    :param program(Program): The parameter server program, which will run on Pserver.
-    :param lookup_table_var_name: the distributed lookup tables var name.
-    :return: None
+    Args:
+        dirname(str): The directory path
+        executor(Executor): The executor to run for loading inference model.
+        program(Program): The parameter server program, which will run on Pserver.
+        lookup_table_var_name: the distributed lookup tables var name.
+    Returns:
+        None
     """
 
-    def __load_lookup_table_vars(executor, dirname, main_program,
-                                 lookup_table_vars):
+    def _load_persistable_vars(executor, dirname, program, lookup_table_vars):
+        def _is_checkpoint_var(exclude_fluid_vars=None):
+            """
+            the checkpoint will not save or load all the variables.
+            var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
+
+            : param var(Variable)
+            """
+
+            if exclude_fluid_vars is None:
+                exclude_fluid_vars = []
+
+            def is_valid(var):
+                if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
+                        var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
+                        var.desc.type() == core.VarDesc.VarType.RAW:
+                    return False
+                # @GRAD are named for gradient variables, checkpoint will not save it.
+                if "@GRAD" in var.name:
+                    return False
+                # .trainer_ are named for distribute train variables, checkpoint will not save it.
+                if ".trainer_" in var.name:
+                    return False
+
+                # .block is named for distribute train variables, checkpoint will not save it.
+                if ".block" in var.name:
+                    return False
+
+                if "tmp_" in var.name:
+                    return False
+
+                if var.name in exclude_fluid_vars:
+                    return False
+
+                return var.persistable
+
+            return is_valid
+
+        io.load_vars(
+            executor,
+            dirname=dirname,
+            main_program=program,
+            predicate=_is_checkpoint_var(lookup_table_vars),
+            filename=None)
+
+    def _load_lookup_table_vars(executor, dirname, main_program,
+                                lookup_table_vars):
         if not os.path.isdir(dirname):
             raise ValueError("There is no directory named '%s'", dirname)
 
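
Note: `_is_checkpoint_var` is a predicate factory: it returns the `is_valid` closure that `io.load_vars` applies to every variable, keeping persistables while skipping feed/fetch/raw vars, gradients, transpiler-generated slices, temporaries, and the explicitly excluded lookup-table vars. A toy sketch of the name-based part of that filter; the variable names here are hypothetical, and the real predicate also checks var type and persistability:

    names = ["fc_0.w_0", "fc_0.w_0@GRAD", "table.block0",
             "learning_rate_0.trainer_1", "tmp_3"]
    skip_markers = ("@GRAD", ".trainer_", ".block", "tmp_")
    kept = [n for n in names if not any(m in n for m in skip_markers)]
    print(kept)  # -> ['fc_0.w_0']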
@@ -313,11 +388,96 @@ def __load_lookup_table_vars(executor, dirname, main_program,
             dirname, time.ctime()))
 
     _load_persistable_vars(executor, dirname, program, [lookup_table_var_name])
-    __load_lookup_table_vars(executor, dirname, program,
-                             [lookup_table_var_name])
+    _load_lookup_table_vars(executor, dirname, program, [lookup_table_var_name])
 
     _logger.info("Finish Load Sparse Program With "
                  "Distributed Lookup Table Vars from {}, time = {}".format(
                      dirname, time.ctime()))
 
     return program
+
+
+def get_inference_model(main_program, feeded_var_names, target_vars):
+    """
+    Prune the given `main_program` to build a new program especially for inference with distributed lookup table ,
+    and then add `feeded_vars` and `target_vars` in this program.
+
+    Args:
+        main_program(Program|None): The original program, which will be pruned to
+                                    build the inference model. If is setted None,
+                                    the default main program will be used.
+                                    Default: None.
+        feeded_var_names(list[str]): Names of variables that need to be feeded data
+                                     during inference.
+        target_vars(list[Variable]): Variables from which we can get inference
+                                     results.
+    Returns:
+        program(Program)
+
+    Raises:
+        ValueError: If `feed_var_names` is not a list of basestring.
+        ValueError: If `target_vars` is not a list of Variable.
+
+    """
+
+    def prepend_feed_ops(inference_program,
+                         feed_target_names,
+                         feed_holder_name='feed'):
+        if len(feed_target_names) == 0:
+            return
+
+        global_block = inference_program.global_block()
+
+        feed_var = global_block.create_var(
+            name=feed_holder_name,
+            type=core.VarDesc.VarType.FEED_MINIBATCH,
+            persistable=True)
+
+        for i, name in enumerate(feed_target_names):
+            out = global_block.var(name)
+            global_block._prepend_op(
+                type='feed',
+                inputs={'X': [feed_var]},
+                outputs={'Out': [out]},
+                attrs={'col': i})
+
+    def append_fetch_ops(inference_program,
+                         fetch_target_names,
+                         fetch_holder_name='fetch'):
+        global_block = inference_program.global_block()
+        fetch_var = global_block.create_var(
+            name=fetch_holder_name,
+            type=core.VarDesc.VarType.FETCH_LIST,
+            persistable=True)
+
+        for i, name in enumerate(fetch_target_names):
+            global_block.append_op(
+                type='fetch',
+                inputs={'X': [name]},
+                outputs={'Out': [fetch_var]},
+                attrs={'col': i})
+
+    origin_program = main_program.clone()
+    main_program = main_program.clone()
+    global_block = main_program.global_block()
+
+    need_to_remove_op_index = []
+    for i, op in enumerate(global_block.ops):
+        op.desc.set_is_target(False)
+        if op.type == "feed" or op.type == "fetch":
+            need_to_remove_op_index.append(i)
+
+    for index in need_to_remove_op_index[::-1]:
+        global_block._remove_op(index)
+
+    main_program.desc.flush()
+
+    main_program = main_program._prune(targets=target_vars)
+    main_program = main_program._inference_optimize(prune_read_op=True)
+
+    fetch_var_names = [v.name for v in target_vars]
+
+    prepend_feed_ops(main_program, feeded_var_names)
+    append_fetch_ops(main_program, fetch_var_names)
+
+    return main_program
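
Note: `get_inference_model` clones the program, drops any existing feed/fetch ops, prunes to the requested targets, and re-attaches feed/fetch ops, mirroring the pruning step of `fluid.io.save_inference_model` without writing to disk. A hedged usage sketch; `main_program`, `image`, and `prediction` are assumed to exist from an already-built network and are not defined in this commit:

    from paddle.fluid.contrib.utils import lookup_table_utils

    infer_program = lookup_table_utils.get_inference_model(
        main_program=main_program,
        feeded_var_names=[image.name],
        target_vars=[prediction])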
