
Commit 103deb1

Merge pull request #13484 from panyx0718/ir4
To support full model saving.
2 parents: bc5fc5c + 10f7d00

File tree

3 files changed: +36 −25 lines changed

python/paddle/fluid/framework.py

Lines changed: 5 additions & 5 deletions
@@ -1647,7 +1647,7 @@ def clone(self, for_test=False):
         The two code snippets above will generate same programs.
         """
         if for_test:
-            p = self._inference_optimize(export_for_deployment=False)
+            p = self._inference_optimize(prune_read_op=False)
         else:
             p = Program()
             p.current_block_idx = self.current_block_idx
@@ -1717,7 +1717,7 @@ def _prune(self, targets):
         res._sync_with_cpp()
         return res

-    def _inference_optimize(self, export_for_deployment=True):
+    def _inference_optimize(self, prune_read_op=True):
         """
         This method will create a new program and do following adjustments on it:
         1. Remove all reader variables and their creator ops if exist.
@@ -1729,8 +1729,8 @@ def _inference_optimize(self, export_for_deployment=True):
         information will be lost.

         Args:
-            export_for_deployment(bool): remove the read ops that are added by py_reader
-                                         for cpp inference library
+            prune_read_op(bool): remove the read ops that are added by py_reader
+                                 for cpp inference library

         Notes: This API is a very low level API. Use
         :code:`Program.clone(for_test=True)` instead.
@@ -1744,7 +1744,7 @@ def _inference_optimize(self, export_for_deployment=True):
         # remove all readers and the read_op if exist
         read_op_idx = 0
         root_block = res.desc.block(0)
-        if export_for_deployment:
+        if prune_read_op:
             while True:
                 if read_op_idx >= root_block.op_size() or root_block.op(
                         read_op_idx).type() == 'read':
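
For context, a minimal sketch of how the renamed argument surfaces through the public API: `Program.clone(for_test=True)` now calls `_inference_optimize(prune_read_op=False)` so reader ops survive, while calling `_inference_optimize()` directly still prunes them. The network, variable names, and layer sizes below are illustrative and not part of this commit.

```python
import paddle.fluid as fluid

# Illustrative toy network; names and sizes are arbitrary.
image = fluid.layers.data(name='image', shape=[784], dtype='float32')
prediction = fluid.layers.fc(input=image, size=10, act='softmax')

main_program = fluid.default_main_program()

# Public API: clone for evaluation. With this change it internally calls
# _inference_optimize(prune_read_op=False), so py_reader read ops are kept.
test_program = main_program.clone(for_test=True)

# Low-level API (discouraged in favor of clone): also prune reader vars/ops.
pruned_program = main_program._inference_optimize(prune_read_op=True)
```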

python/paddle/fluid/io.py

Lines changed: 30 additions & 19 deletions
@@ -20,6 +20,7 @@
 import shutil
 import six

+from paddle.fluid.executor import Executor
 from paddle.fluid.evaluator import Evaluator
 from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable
 from . import core
@@ -587,8 +588,11 @@ def save_inference_model(dirname,
         params_filename(str|None): The name of file to save all related parameters.
                                    If it is setted None, parameters will be saved
                                    in separate files .
-        export_for_deployment(bool): remove the read ops that are added by py_reader
-                                     for cpp inference lib. Default True
+        export_for_deployment(bool): If True, programs are modified to only support
+                                     direct inference deployment. Otherwise,
+                                     more information will be stored for flexible
+                                     optimization and re-training. Currently, only
+                                     True is supported.

     Returns:
         None
@@ -636,21 +640,28 @@ def save_inference_model(dirname,
     if not os.path.isdir(dirname):
         os.makedirs(dirname)

-    # Clear the is_target information and remove the existed feed and fetch op
-    global_block = copy_program.global_block()
-    for i, op in enumerate(global_block.ops):
-        op.desc.set_is_target(False)
-        if op.type == "feed" or op.type == "fetch":
-            global_block._remove_op(i)
-    copy_program.desc.flush()
-
-    pruned_program = copy_program._prune(targets=target_vars)
-    inference_program = pruned_program._inference_optimize(
-        export_for_deployment=export_for_deployment)
-    fetch_var_names = [v.name for v in target_vars]
-
-    prepend_feed_ops(inference_program, feeded_var_names)
-    append_fetch_ops(inference_program, fetch_var_names)
+    # When export_for_deployment is true, we modify the program online so that
+    # it can only be loaded for inference directly. If it's false, the whole
+    # original program and related meta are saved so that future usage can be
+    # more flexible.
+    if export_for_deployment:
+        global_block = copy_program.global_block()
+        for i, op in enumerate(global_block.ops):
+            op.desc.set_is_target(False)
+            if op.type == "feed" or op.type == "fetch":
+                global_block._remove_op(i)
+        copy_program.desc.flush()
+
+        pruned_program = copy_program._prune(targets=target_vars)
+        saved_program = pruned_program._inference_optimize(prune_read_op=True)
+        fetch_var_names = [v.name for v in target_vars]
+
+        prepend_feed_ops(saved_program, feeded_var_names)
+        append_fetch_ops(saved_program, fetch_var_names)
+    else:
+        # TODO(panyx0718): Save more information so that it can also be used
+        # for training and more flexible post-processing.
+        saved_program = copy_program

     if model_filename is not None:
         model_filename = os.path.basename(model_filename)
@@ -662,9 +673,9 @@ def save_inference_model(dirname,
         params_filename = os.path.basename(params_filename)

     with open(model_filename, "wb") as f:
-        f.write(inference_program.desc.serialize_to_string())
+        f.write(saved_program.desc.serialize_to_string())

-    save_persistables(executor, dirname, inference_program, params_filename)
+    save_persistables(executor, dirname, saved_program, params_filename)

     # if there is lookup table, the trainer 0 will notify all pserver to save.
     if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
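
A hedged usage sketch of `save_inference_model` with the flag documented above; per the updated docstring only `export_for_deployment=True` is supported at this point. The output directory, network, and variable names here are made up for illustration.

```python
import paddle.fluid as fluid

# Illustrative toy network; names and sizes are arbitrary.
image = fluid.layers.data(name='image', shape=[784], dtype='float32')
prediction = fluid.layers.fc(input=image, size=10, act='softmax')

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())  # initialize parameters before saving

# export_for_deployment=True (currently the only supported value) prunes the
# feed/fetch and reader ops so the saved program can be loaded directly for
# inference; False is reserved for saving the full program for re-training.
fluid.io.save_inference_model(
    dirname='./inference_model',        # hypothetical output directory
    feeded_var_names=['image'],
    target_vars=[prediction],
    executor=exe,
    export_for_deployment=True)

# The saved model can later be restored for inference.
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    './inference_model', exe)
```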

python/paddle/fluid/tests/unittests/test_program.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def net():
         net()
         no_read_program = main_program._inference_optimize()
         keep_read_program = main_program._inference_optimize(
-            export_for_deployment=False)
+            prune_read_op=False)
         no_read_ops = no_read_program.global_block().ops
         keep_read_ops = keep_read_program.global_block().ops
         self.assertEqual(len(keep_read_ops) - len(no_read_ops), 2)
