@@ -637,8 +637,8 @@ def save_inference_model(dirname,
637
637
if isinstance (target_vars , Variable ):
638
638
target_vars = [target_vars ]
639
639
elif export_for_deployment :
640
- if not (bool (target_vars ) and all (
641
- isinstance (var , Variable ) for var in target_vars )):
640
+ if not (bool (target_vars ) and
641
+ all ( isinstance (var , Variable ) for var in target_vars )):
642
642
raise ValueError ("'target_vars' should be a list of Variable." )
643
643
644
644
if main_program is None :
@@ -667,10 +667,15 @@ def save_inference_model(dirname,
667
667
if export_for_deployment :
668
668
main_program = main_program .clone ()
669
669
global_block = main_program .global_block ()
670
+ need_to_remove_op_index = []
670
671
for i , op in enumerate (global_block .ops ):
671
672
op .desc .set_is_target (False )
672
673
if op .type == "feed" or op .type == "fetch" :
673
- global_block ._remove_op (i )
674
+ need_to_remove_op_index .append (i )
675
+
676
+ for index in need_to_remove_op_index [::- 1 ]:
677
+ global_block ._remove_op (index )
678
+
674
679
main_program .desc .flush ()
675
680
676
681
main_program = main_program ._prune (targets = target_vars )
0 commit comments