@@ -20,6 +20,7 @@
 import shutil
 import six
 
+from paddle.fluid.executor import Executor
 from paddle.fluid.evaluator import Evaluator
 from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable
 from . import core
@@ -587,8 +588,11 @@ def save_inference_model(dirname,
         params_filename(str|None): The name of file to save all related parameters.
                                    If it is set to None, parameters will be saved
                                    in separate files.
-        export_for_deployment(bool): remove the read ops that are added by py_reader
-                                     for cpp inference lib. Default True
+        export_for_deployment(bool): If True, programs are modified to only support
+                                     direct inference deployment. Otherwise,
+                                     more information will be stored for flexible
+                                     optimization and re-training. Currently, only
+                                     True is supported.
 
     Returns:
         None
@@ -636,21 +640,28 @@ def save_inference_model(dirname,
     if not os.path.isdir(dirname):
         os.makedirs(dirname)
 
-    # Clear the is_target information and remove the existed feed and fetch op
-    global_block = copy_program.global_block()
-    for i, op in enumerate(global_block.ops):
-        op.desc.set_is_target(False)
-        if op.type == "feed" or op.type == "fetch":
-            global_block._remove_op(i)
-    copy_program.desc.flush()
-
-    pruned_program = copy_program._prune(targets=target_vars)
-    inference_program = pruned_program._inference_optimize(
-        export_for_deployment=export_for_deployment)
-    fetch_var_names = [v.name for v in target_vars]
-
-    prepend_feed_ops(inference_program, feeded_var_names)
-    append_fetch_ops(inference_program, fetch_var_names)
+    # When export_for_deployment is true, we modify the program online so that
+    # it can only be loaded for inference directly. If it's false, the whole
+    # original program and related meta are saved so that future usage can be
+    # more flexible.
+    if export_for_deployment:
+        global_block = copy_program.global_block()
+        for i, op in enumerate(global_block.ops):
+            op.desc.set_is_target(False)
+            if op.type == "feed" or op.type == "fetch":
+                global_block._remove_op(i)
+        copy_program.desc.flush()
+
+        pruned_program = copy_program._prune(targets=target_vars)
+        saved_program = pruned_program._inference_optimize(prune_read_op=True)
+        fetch_var_names = [v.name for v in target_vars]
+
+        prepend_feed_ops(saved_program, feeded_var_names)
+        append_fetch_ops(saved_program, fetch_var_names)
+    else:
+        # TODO(panyx0718): Save more information so that it can also be used
+        # for training and more flexible post-processing.
+        saved_program = copy_program
 
     if model_filename is not None:
         model_filename = os.path.basename(model_filename)
@@ -662,9 +673,9 @@ def save_inference_model(dirname,
         params_filename = os.path.basename(params_filename)
 
     with open(model_filename, "wb") as f:
-        f.write(inference_program.desc.serialize_to_string())
+        f.write(saved_program.desc.serialize_to_string())
 
-    save_persistables(executor, dirname, inference_program, params_filename)
+    save_persistables(executor, dirname, saved_program, params_filename)
 
     # if there is lookup table, the trainer 0 will notify all pserver to save.
     if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
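
For context, a minimal usage sketch of the updated save_inference_model API. The toy network, the "img" feed name, the predict variable, and the output directory are illustrative assumptions, not taken from this diff:

import paddle.fluid as fluid

# Tiny illustrative network; the layer sizes and the "img" feed name are assumed.
img = fluid.layers.data(name="img", shape=[784], dtype="float32")
predict = fluid.layers.fc(input=img, size=10, act="softmax")

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
# ... training would normally happen here ...

# export_for_deployment=True (the only supported value per this change) prunes
# the program so the saved copy targets direct inference deployment.
fluid.io.save_inference_model(
    dirname="./infer_model",
    feeded_var_names=["img"],
    target_vars=[predict],
    executor=exe,
    export_for_deployment=True)

# The saved program can then be loaded back directly for inference.
infer_prog, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname="./infer_model", executor=exe)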