@@ -326,8 +326,8 @@ class Function:
326
326
def __init__ (
327
327
self ,
328
328
vm : "VM" ,
329
- input_storage ,
330
- output_storage ,
329
+ input_storage : list [ Container ] ,
330
+ output_storage : list [ Container ] ,
331
331
indices ,
332
332
outputs ,
333
333
defaults ,
@@ -388,6 +388,11 @@ def __init__(
388
388
self .nodes_with_inner_function = []
389
389
self .output_keys = output_keys
390
390
391
+ if self .output_keys is not None :
392
+ warnings .warn (FutureWarning , "output_keys is deprecated." )
393
+
394
+ assert len (self .output_storage ) == len (self .maker .fgraph .outputs )
395
+
391
396
# See if we have any mutable / borrow inputs
392
397
# TODO: this only need to be set if there is more than one input
393
398
self ._check_for_aliased_inputs = False
@@ -408,11 +413,6 @@ def __init__(
408
413
finder = {}
409
414
inv_finder = {}
410
415
411
- def distribute (indices , cs , value ):
412
- input .distribute (value , indices , cs )
413
- for c in cs :
414
- c .provided += 1
415
-
416
416
# Store the list of names of named inputs.
417
417
named_inputs = []
418
418
# Count the number of un-named inputs.
@@ -777,6 +777,13 @@ def checkSV(sv_ori, sv_rpl):
777
777
f_cpy .maker .fgraph .name = name
778
778
return f_cpy
779
779
780
+ def _restore_defaults (self ):
781
+ for i , (required , refeed , value ) in enumerate (self .defaults ):
782
+ if refeed :
783
+ if isinstance (value , Container ):
784
+ value = value .storage [0 ]
785
+ self [i ] = value
786
+
780
787
def __call__ (self , * args , ** kwargs ):
781
788
"""
782
789
Evaluates value of a function on given arguments.
@@ -806,51 +813,45 @@ def __call__(self, *args, **kwargs):
806
813
if ``output_subset`` is not passed.
807
814
"""
808
815
809
- def restore_defaults ():
810
- for i , (required , refeed , value ) in enumerate (self .defaults ):
811
- if refeed :
812
- if isinstance (value , Container ):
813
- value = value .storage [0 ]
814
- self [i ] = value
815
-
816
816
profile = self .profile
817
- t0 = time .perf_counter ()
817
+ if profile is not None :
818
+ t0 = time .perf_counter ()
818
819
819
820
output_subset = kwargs .pop ("output_subset" , None )
820
- if output_subset is not None and self .output_keys is not None :
821
- output_subset = [self .output_keys .index (key ) for key in output_subset ]
821
+ if output_subset is not None :
822
+ warnings .warn (FutureWarning , "output_subset is deprecated." )
823
+ if self .output_keys is not None :
824
+ output_subset = [self .output_keys .index (key ) for key in output_subset ]
822
825
823
826
# Reinitialize each container's 'provided' counter
824
827
if self .trust_input :
825
- i = 0
826
- for arg in args :
827
- s = self .input_storage [i ]
828
- s .storage [0 ] = arg
829
- i += 1
828
+ for arg_container , arg in zip (self .input_storage , args , strict = False ):
829
+ arg_container .storage [0 ] = arg
830
830
else :
831
- for c in self .input_storage :
832
- c .provided = 0
831
+ for arg_container in self .input_storage :
832
+ arg_container .provided = 0
833
833
834
834
if len (args ) + len (kwargs ) > len (self .input_storage ):
835
835
raise TypeError ("Too many parameter passed to pytensor function" )
836
836
837
837
# Set positional arguments
838
- i = 0
839
- for arg in args :
838
+ for arg_container , arg in zip (self .input_storage , args , strict = False ):
840
839
# TODO: provide a option for skipping the filter if we really
841
840
# want speed.
842
- s = self .input_storage [i ]
843
841
# see this emails for a discuation about None as input
844
842
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
845
843
if arg is None :
846
- s .storage [0 ] = arg
844
+ arg_container .storage [0 ] = arg
847
845
else :
848
846
try :
849
- s .storage [0 ] = s .type .filter (
850
- arg , strict = s .strict , allow_downcast = s .allow_downcast
847
+ arg_container .storage [0 ] = arg_container .type .filter (
848
+ arg ,
849
+ strict = arg_container .strict ,
850
+ allow_downcast = arg_container .allow_downcast ,
851
851
)
852
852
853
853
except Exception as e :
854
+ i = self .input_storage .index (arg_container )
854
855
function_name = "pytensor function"
855
856
argument_name = "argument"
856
857
if self .name :
@@ -875,27 +876,23 @@ def restore_defaults():
875
876
+ function_name
876
877
+ f" at index { int (i )} (0-based). { where } "
877
878
) + e .args
878
- restore_defaults ()
879
+ self . _restore_defaults ()
879
880
raise
880
- s .provided += 1
881
- i += 1
881
+ arg_container .provided += 1
882
882
883
883
# Set keyword arguments
884
884
if kwargs : # for speed, skip the items for empty kwargs
885
885
for k , arg in kwargs .items ():
886
886
self [k ] = arg
887
887
888
- if (
889
- not self .trust_input
890
- and
891
- # The getattr is only needed for old pickle
892
- getattr (self , "_check_for_aliased_inputs" , True )
893
- ):
888
+ if not self .trust_input :
894
889
# Collect aliased inputs among the storage space
895
890
args_share_memory = []
896
- for i in range (len (self .input_storage )):
897
- i_var = self .maker .inputs [i ].variable
898
- i_val = self .input_storage [i ].storage [0 ]
891
+ for i , (inp , arg_storage ) in enumerate (
892
+ zip (self .maker .inputs , self .input_storage )
893
+ ):
894
+ i_var = inp .variable
895
+ i_val = arg_storage .storage [0 ]
899
896
if hasattr (i_var .type , "may_share_memory" ):
900
897
is_aliased = False
901
898
for j in range (len (args_share_memory )):
@@ -932,36 +929,36 @@ def restore_defaults():
932
929
self .input_storage [j ].storage [0 ]
933
930
)
934
931
935
- # Check if inputs are missing, or if inputs were set more than once, or
936
- # if we tried to provide inputs that are supposed to be implicit.
937
- if not self .trust_input :
938
- for c in self .input_storage :
939
- if c .required and not c .provided :
940
- restore_defaults ()
932
+ # Check if inputs are missing, or if inputs were set more than once, or
933
+ # if we tried to provide inputs that are supposed to be implicit.
934
+ for arg_container in self .input_storage :
935
+ if arg_container .required and not arg_container .provided :
936
+ self ._restore_defaults ()
941
937
raise TypeError (
942
- f"Missing required input: { getattr (self .inv_finder [c ], 'variable' , self .inv_finder [c ])} "
938
+ f"Missing required input: { getattr (self .inv_finder [arg_container ], 'variable' , self .inv_finder [arg_container ])} "
943
939
)
944
- if c .provided > 1 :
945
- restore_defaults ()
940
+ if arg_container .provided > 1 :
941
+ self . _restore_defaults ()
946
942
raise TypeError (
947
- f"Multiple values for input: { getattr (self .inv_finder [c ], 'variable' , self .inv_finder [c ])} "
943
+ f"Multiple values for input: { getattr (self .inv_finder [arg_container ], 'variable' , self .inv_finder [arg_container ])} "
948
944
)
949
- if c .implicit and c .provided > 0 :
950
- restore_defaults ()
945
+ if arg_container .implicit and arg_container .provided > 0 :
946
+ self . _restore_defaults ()
951
947
raise TypeError (
952
- f"Tried to provide value for implicit input: { getattr (self .inv_finder [c ], 'variable' , self .inv_finder [c ])} "
948
+ f"Tried to provide value for implicit input: { getattr (self .inv_finder [arg_container ], 'variable' , self .inv_finder [arg_container ])} "
953
949
)
954
950
955
951
# Do the actual work
956
- t0_fn = time .perf_counter ()
952
+ if profile is not None :
953
+ t0_fn = time .perf_counter ()
957
954
try :
958
955
outputs = (
959
956
self .vm ()
960
957
if output_subset is None
961
958
else self .vm (output_subset = output_subset )
962
959
)
963
960
except Exception :
964
- restore_defaults ()
961
+ self . _restore_defaults ()
965
962
if hasattr (self .vm , "position_of_error" ):
966
963
# this is a new vm-provided function or c linker
967
964
# they need this because the exception manipulation
@@ -979,9 +976,9 @@ def restore_defaults():
979
976
# old-style linkers raise their own exceptions
980
977
raise
981
978
982
- dt_fn = time . perf_counter () - t0_fn
983
- self . maker . mode . fn_time += dt_fn
984
- if profile :
979
+ if profile is not None :
980
+ dt_fn = time . perf_counter () - t0_fn
981
+ self . maker . mode . fn_time += dt_fn
985
982
profile .vm_call_time += dt_fn
986
983
987
984
# Retrieve the values that were computed
@@ -991,14 +988,13 @@ def restore_defaults():
991
988
992
989
# Remove internal references to required inputs.
993
990
# These cannot be re-used anyway.
994
- for c in self .input_storage :
995
- if c .required :
996
- c .storage [0 ] = None
991
+ for arg_container in self .input_storage :
992
+ if arg_container .required :
993
+ arg_container .storage [0 ] = None
997
994
998
995
# if we are allowing garbage collection, remove the
999
996
# output reference from the internal storage cells
1000
997
if getattr (self .vm , "allow_gc" , False ):
1001
- assert len (self .output_storage ) == len (self .maker .fgraph .outputs )
1002
998
for o_container , o_variable in zip (
1003
999
self .output_storage , self .maker .fgraph .outputs
1004
1000
):
@@ -1020,17 +1016,12 @@ def restore_defaults():
1020
1016
outputs = outputs [: self .n_returned_outputs ]
1021
1017
1022
1018
# Put default values back in the storage
1023
- restore_defaults ()
1024
- #
1025
- # NOTE: This logic needs to be replicated in
1026
- # scan.
1027
- # grep for 'PROFILE_CODE'
1028
- #
1029
-
1030
- dt_call = time .perf_counter () - t0
1031
- pytensor .compile .profiling .total_fct_exec_time += dt_call
1032
- self .maker .mode .call_time += dt_call
1033
- if profile :
1019
+ self ._restore_defaults ()
1020
+
1021
+ if profile is not None :
1022
+ dt_call = time .perf_counter () - t0
1023
+ pytensor .compile .profiling .total_fct_exec_time += dt_call
1024
+ self .maker .mode .call_time += dt_call
1034
1025
profile .fct_callcount += 1
1035
1026
profile .fct_call_time += dt_call
1036
1027
if hasattr (self .vm , "update_profile" ):
0 commit comments