
Commit 10f7d00

To support full model saving.
In the future, we'd like to encourage users to save everything during training. This allows us to (1) do more flexible optimization passes and (2) re-train and fine-tune.
1 parent a69a584 commit 10f7d00

File tree

3 files changed (+36 −25 lines)


python/paddle/fluid/framework.py

Lines changed: 5 additions & 5 deletions
@@ -1647,7 +1647,7 @@ def clone(self, for_test=False):
         The two code snippets above will generate same programs.
         """
         if for_test:
-            p = self._inference_optimize(export_for_deployment=False)
+            p = self._inference_optimize(prune_read_op=False)
         else:
             p = Program()
             p.current_block_idx = self.current_block_idx
@@ -1717,7 +1717,7 @@ def _prune(self, targets):
         res._sync_with_cpp()
         return res
 
-    def _inference_optimize(self, export_for_deployment=True):
+    def _inference_optimize(self, prune_read_op=True):
         """
         This method will create a new program and do following adjustments on it:
         1. Remove all reader variables and their creator ops if exist.
@@ -1729,8 +1729,8 @@ def _inference_optimize(self, export_for_deployment=True):
         information will be lost.
 
         Args:
-            export_for_deployment(bool): remove the read ops that are added by py_reader
-                                         for cpp inference library
+            prune_read_op(bool): remove the read ops that are added by py_reader
+                                 for cpp inference library
 
         Notes: This API is a very low level API. Use
         :code:`Program.clone(for_test=True)` instead.
@@ -1744,7 +1744,7 @@ def _inference_optimize(self, export_for_deployment=True):
         # remove all readers and the read_op if exist
         read_op_idx = 0
         root_block = res.desc.block(0)
-        if export_for_deployment:
+        if prune_read_op:
             while True:
                 if read_op_idx >= root_block.op_size() or root_block.op(
                         read_op_idx).type() == 'read':
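After this rename, `Program.clone(for_test=True)` keeps reader ops while a bare `_inference_optimize()` prunes them. A minimal sketch of the distinction (the toy network and variable names are illustrative, not from this commit):

    import paddle.fluid as fluid

    main_program = fluid.Program()
    startup_program = fluid.Program()
    with fluid.program_guard(main_program, startup_program):
        x = fluid.layers.data(name='x', shape=[13], dtype='float32')
        y = fluid.layers.fc(input=x, size=1)

    # clone(for_test=True) now calls _inference_optimize(prune_read_op=False),
    # so reader variables and their read ops survive in the test program.
    test_program = main_program.clone(for_test=True)

    # The default (prune_read_op=True) strips py_reader read ops, matching the
    # old export_for_deployment=True behavior of this method.
    pruned_program = main_program._inference_optimize()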

python/paddle/fluid/io.py

Lines changed: 30 additions & 19 deletions
@@ -20,6 +20,7 @@
 import shutil
 import six
 
+from paddle.fluid.executor import Executor
 from paddle.fluid.evaluator import Evaluator
 from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable
 from . import core
@@ -587,8 +588,11 @@ def save_inference_model(dirname,
         params_filename(str|None): The name of file to save all related parameters.
                                    If it is setted None, parameters will be saved
                                    in separate files .
-        export_for_deployment(bool): remove the read ops that are added by py_reader
-                                     for cpp inference lib. Default True
+        export_for_deployment(bool): If True, programs are modified to only support
+                                     direct inference deployment. Otherwise,
+                                     more information will be stored for flexible
+                                     optimization and re-training. Currently, only
+                                     True is supported.
 
     Returns:
         None
@@ -636,21 +640,28 @@ def save_inference_model(dirname,
     if not os.path.isdir(dirname):
         os.makedirs(dirname)
 
-    # Clear the is_target information and remove the existed feed and fetch op
-    global_block = copy_program.global_block()
-    for i, op in enumerate(global_block.ops):
-        op.desc.set_is_target(False)
-        if op.type == "feed" or op.type == "fetch":
-            global_block._remove_op(i)
-    copy_program.desc.flush()
-
-    pruned_program = copy_program._prune(targets=target_vars)
-    inference_program = pruned_program._inference_optimize(
-        export_for_deployment=export_for_deployment)
-    fetch_var_names = [v.name for v in target_vars]
-
-    prepend_feed_ops(inference_program, feeded_var_names)
-    append_fetch_ops(inference_program, fetch_var_names)
+    # When export_for_deployment is true, we modify the program online so that
+    # it can only be loaded for inference directly. If it's false, the whole
+    # original program and related meta are saved so that future usage can be
+    # more flexible.
+    if export_for_deployment:
+        global_block = copy_program.global_block()
+        for i, op in enumerate(global_block.ops):
+            op.desc.set_is_target(False)
+            if op.type == "feed" or op.type == "fetch":
+                global_block._remove_op(i)
+        copy_program.desc.flush()
+
+        pruned_program = copy_program._prune(targets=target_vars)
+        saved_program = pruned_program._inference_optimize(prune_read_op=True)
+        fetch_var_names = [v.name for v in target_vars]
+
+        prepend_feed_ops(saved_program, feeded_var_names)
+        append_fetch_ops(saved_program, fetch_var_names)
+    else:
+        # TODO(panyx0718): Save more information so that it can also be used
+        # for training and more flexible post-processing.
+        saved_program = copy_program
 
     if model_filename is not None:
         model_filename = os.path.basename(model_filename)
@@ -662,9 +673,9 @@ def save_inference_model(dirname,
         params_filename = os.path.basename(params_filename)
 
     with open(model_filename, "wb") as f:
-        f.write(inference_program.desc.serialize_to_string())
+        f.write(saved_program.desc.serialize_to_string())
 
-    save_persistables(executor, dirname, inference_program, params_filename)
+    save_persistables(executor, dirname, saved_program, params_filename)
 
     # if there is lookup table, the trainer 0 will notify all pserver to save.
     if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
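A rough sketch of how the two save paths are driven from user code, assuming the toy network above (the directory name and feed/fetch names are hypothetical):

    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_program)

    # Default path: prune to an inference-only program, then serialize it
    # together with its persistable parameters.
    fluid.io.save_inference_model(
        dirname='./infer_model',
        feeded_var_names=['x'],
        target_vars=[y],
        executor=exe,
        main_program=main_program,
        export_for_deployment=True)

    # export_for_deployment=False would keep the whole original program for
    # re-training, but per the docstring above only True is supported so far.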

python/paddle/fluid/tests/unittests/test_program.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def net():
         net()
         no_read_program = main_program._inference_optimize()
         keep_read_program = main_program._inference_optimize(
-            export_for_deployment=False)
+            prune_read_op=False)
         no_read_ops = no_read_program.global_block().ops
         keep_read_ops = keep_read_program.global_block().ops
         self.assertEqual(len(keep_read_ops) - len(no_read_ops), 2)
