@@ -682,15 +682,19 @@ def _get_varname_from_block(block):
682
682
)
683
683
684
684
685
- def _get_program_cache_key (feed , fetch_list ):
685
+ def _get_feed_fetch_var_names (feed , fetch_list ):
686
686
feed_var_names = []
687
687
if isinstance (feed , dict ):
688
688
feed_var_names = list (feed .keys ())
689
689
elif isinstance (feed , (list , tuple )):
690
690
for i , each in enumerate (feed ):
691
691
feed_var_names += list (each .keys ())
692
692
fetch_var_names = list (map (_to_name_str , fetch_list ))
693
- return str (feed_var_names + fetch_var_names )
693
+ return feed_var_names + fetch_var_names
694
+
695
+
696
+ def _get_program_cache_key (feed , fetch_list ):
697
+ return str (_get_feed_fetch_var_names (feed , fetch_list ))
694
698
695
699
696
700
def _as_lodtensor (data , place , dtype = None ):
@@ -1026,7 +1030,7 @@ def _get_program_and_executor(self, cached_data):
1026
1030
1027
1031
if enable_inplace or enable_addto :
1028
1032
# inplace should skip feed and fetch var
1029
- skip_var_names = eval ( _get_program_cache_key ( feed , fetch_list ) )
1033
+ skip_var_names = _get_feed_fetch_var_names ( feed , fetch_list )
1030
1034
_apply_inplace_addto_pass (
1031
1035
program , enable_inplace , enable_addto , skip_var_names
1032
1036
)
0 commit comments