@@ -326,8 +326,8 @@ class Function:
326326 def __init__ (
327327 self ,
328328 vm : "VM" ,
329- input_storage ,
330- output_storage ,
329+ input_storage : list [ Container ] ,
330+ output_storage : list [ Container ] ,
331331 indices ,
332332 outputs ,
333333 defaults ,
@@ -388,6 +388,11 @@ def __init__(
388388 self .nodes_with_inner_function = []
389389 self .output_keys = output_keys
390390
391+ if self .output_keys is not None :
392+ warnings .warn (FutureWarning , "output_keys is deprecated." )
393+
394+ assert len (self .output_storage ) == len (self .maker .fgraph .outputs )
395+
391396 # See if we have any mutable / borrow inputs
392397 # TODO: this only need to be set if there is more than one input
393398 self ._check_for_aliased_inputs = False
@@ -408,11 +413,6 @@ def __init__(
408413 finder = {}
409414 inv_finder = {}
410415
411- def distribute (indices , cs , value ):
412- input .distribute (value , indices , cs )
413- for c in cs :
414- c .provided += 1
415-
416416 # Store the list of names of named inputs.
417417 named_inputs = []
418418 # Count the number of un-named inputs.
@@ -777,6 +777,13 @@ def checkSV(sv_ori, sv_rpl):
777777 f_cpy .maker .fgraph .name = name
778778 return f_cpy
779779
780+ def _restore_defaults (self ):
781+ for i , (required , refeed , value ) in enumerate (self .defaults ):
782+ if refeed :
783+ if isinstance (value , Container ):
784+ value = value .storage [0 ]
785+ self [i ] = value
786+
780787 def __call__ (self , * args , ** kwargs ):
781788 """
782789 Evaluates value of a function on given arguments.
@@ -806,51 +813,45 @@ def __call__(self, *args, **kwargs):
806813 if ``output_subset`` is not passed.
807814 """
808815
809- def restore_defaults ():
810- for i , (required , refeed , value ) in enumerate (self .defaults ):
811- if refeed :
812- if isinstance (value , Container ):
813- value = value .storage [0 ]
814- self [i ] = value
815-
816816 profile = self .profile
817- t0 = time .perf_counter ()
817+ if profile is not None :
818+ t0 = time .perf_counter ()
818819
819820 output_subset = kwargs .pop ("output_subset" , None )
820- if output_subset is not None and self .output_keys is not None :
821- output_subset = [self .output_keys .index (key ) for key in output_subset ]
821+ if output_subset is not None :
822+ warnings .warn (FutureWarning , "output_subset is deprecated." )
823+ if self .output_keys is not None :
824+ output_subset = [self .output_keys .index (key ) for key in output_subset ]
822825
823826 # Reinitialize each container's 'provided' counter
824827 if self .trust_input :
825- i = 0
826- for arg in args :
827- s = self .input_storage [i ]
828- s .storage [0 ] = arg
829- i += 1
828+ for arg_container , arg in zip (self .input_storage , args , strict = False ):
829+ arg_container .storage [0 ] = arg
830830 else :
831- for c in self .input_storage :
832- c .provided = 0
831+ for arg_container in self .input_storage :
832+ arg_container .provided = 0
833833
834834 if len (args ) + len (kwargs ) > len (self .input_storage ):
835835 raise TypeError ("Too many parameter passed to pytensor function" )
836836
837837 # Set positional arguments
838- i = 0
839- for arg in args :
838+ for arg_container , arg in zip (self .input_storage , args , strict = False ):
840839 # TODO: provide a option for skipping the filter if we really
841840 # want speed.
842- s = self .input_storage [i ]
843841 # see this emails for a discuation about None as input
844842 # https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
845843 if arg is None :
846- s .storage [0 ] = arg
844+ arg_container .storage [0 ] = arg
847845 else :
848846 try :
849- s .storage [0 ] = s .type .filter (
850- arg , strict = s .strict , allow_downcast = s .allow_downcast
847+ arg_container .storage [0 ] = arg_container .type .filter (
848+ arg ,
849+ strict = arg_container .strict ,
850+ allow_downcast = arg_container .allow_downcast ,
851851 )
852852
853853 except Exception as e :
854+ i = self .input_storage .index (arg_container )
854855 function_name = "pytensor function"
855856 argument_name = "argument"
856857 if self .name :
@@ -875,27 +876,23 @@ def restore_defaults():
875876 + function_name
876877 + f" at index { int (i )} (0-based). { where } "
877878 ) + e .args
878- restore_defaults ()
879+ self . _restore_defaults ()
879880 raise
880- s .provided += 1
881- i += 1
881+ arg_container .provided += 1
882882
883883 # Set keyword arguments
884884 if kwargs : # for speed, skip the items for empty kwargs
885885 for k , arg in kwargs .items ():
886886 self [k ] = arg
887887
888- if (
889- not self .trust_input
890- and
891- # The getattr is only needed for old pickle
892- getattr (self , "_check_for_aliased_inputs" , True )
893- ):
888+ if not self .trust_input :
894889 # Collect aliased inputs among the storage space
895890 args_share_memory = []
896- for i in range (len (self .input_storage )):
897- i_var = self .maker .inputs [i ].variable
898- i_val = self .input_storage [i ].storage [0 ]
891+ for i , (inp , arg_storage ) in enumerate (
892+ zip (self .maker .inputs , self .input_storage )
893+ ):
894+ i_var = inp .variable
895+ i_val = arg_storage .storage [0 ]
899896 if hasattr (i_var .type , "may_share_memory" ):
900897 is_aliased = False
901898 for j in range (len (args_share_memory )):
@@ -932,36 +929,36 @@ def restore_defaults():
932929 self .input_storage [j ].storage [0 ]
933930 )
934931
935- # Check if inputs are missing, or if inputs were set more than once, or
936- # if we tried to provide inputs that are supposed to be implicit.
937- if not self .trust_input :
938- for c in self .input_storage :
939- if c .required and not c .provided :
940- restore_defaults ()
932+ # Check if inputs are missing, or if inputs were set more than once, or
933+ # if we tried to provide inputs that are supposed to be implicit.
934+ for arg_container in self .input_storage :
935+ if arg_container .required and not arg_container .provided :
936+ self ._restore_defaults ()
941937 raise TypeError (
942- f"Missing required input: { getattr (self .inv_finder [c ], 'variable' , self .inv_finder [c ])} "
938+ f"Missing required input: { getattr (self .inv_finder [arg_container ], 'variable' , self .inv_finder [arg_container ])} "
943939 )
944- if c .provided > 1 :
945- restore_defaults ()
940+ if arg_container .provided > 1 :
941+ self . _restore_defaults ()
946942 raise TypeError (
947- f"Multiple values for input: { getattr (self .inv_finder [c ], 'variable' , self .inv_finder [c ])} "
943+ f"Multiple values for input: { getattr (self .inv_finder [arg_container ], 'variable' , self .inv_finder [arg_container ])} "
948944 )
949- if c .implicit and c .provided > 0 :
950- restore_defaults ()
945+ if arg_container .implicit and arg_container .provided > 0 :
946+ self . _restore_defaults ()
951947 raise TypeError (
952- f"Tried to provide value for implicit input: { getattr (self .inv_finder [c ], 'variable' , self .inv_finder [c ])} "
948+ f"Tried to provide value for implicit input: { getattr (self .inv_finder [arg_container ], 'variable' , self .inv_finder [arg_container ])} "
953949 )
954950
955951 # Do the actual work
956- t0_fn = time .perf_counter ()
952+ if profile is not None :
953+ t0_fn = time .perf_counter ()
957954 try :
958955 outputs = (
959956 self .vm ()
960957 if output_subset is None
961958 else self .vm (output_subset = output_subset )
962959 )
963960 except Exception :
964- restore_defaults ()
961+ self . _restore_defaults ()
965962 if hasattr (self .vm , "position_of_error" ):
966963 # this is a new vm-provided function or c linker
967964 # they need this because the exception manipulation
@@ -979,9 +976,9 @@ def restore_defaults():
979976 # old-style linkers raise their own exceptions
980977 raise
981978
982- dt_fn = time . perf_counter () - t0_fn
983- self . maker . mode . fn_time += dt_fn
984- if profile :
979+ if profile is not None :
980+ dt_fn = time . perf_counter () - t0_fn
981+ self . maker . mode . fn_time += dt_fn
985982 profile .vm_call_time += dt_fn
986983
987984 # Retrieve the values that were computed
@@ -991,14 +988,13 @@ def restore_defaults():
991988
992989 # Remove internal references to required inputs.
993990 # These cannot be re-used anyway.
994- for c in self .input_storage :
995- if c .required :
996- c .storage [0 ] = None
991+ for arg_container in self .input_storage :
992+ if arg_container .required :
993+ arg_container .storage [0 ] = None
997994
998995 # if we are allowing garbage collection, remove the
999996 # output reference from the internal storage cells
1000997 if getattr (self .vm , "allow_gc" , False ):
1001- assert len (self .output_storage ) == len (self .maker .fgraph .outputs )
1002998 for o_container , o_variable in zip (
1003999 self .output_storage , self .maker .fgraph .outputs
10041000 ):
@@ -1020,17 +1016,12 @@ def restore_defaults():
10201016 outputs = outputs [: self .n_returned_outputs ]
10211017
10221018 # Put default values back in the storage
1023- restore_defaults ()
1024- #
1025- # NOTE: This logic needs to be replicated in
1026- # scan.
1027- # grep for 'PROFILE_CODE'
1028- #
1029-
1030- dt_call = time .perf_counter () - t0
1031- pytensor .compile .profiling .total_fct_exec_time += dt_call
1032- self .maker .mode .call_time += dt_call
1033- if profile :
1019+ self ._restore_defaults ()
1020+
1021+ if profile is not None :
1022+ dt_call = time .perf_counter () - t0
1023+ pytensor .compile .profiling .total_fct_exec_time += dt_call
1024+ self .maker .mode .call_time += dt_call
10341025 profile .fct_callcount += 1
10351026 profile .fct_call_time += dt_call
10361027 if hasattr (self .vm , "update_profile" ):
0 commit comments