@@ -540,15 +540,42 @@ def __contains__(self, item):
         self._value = ValueAttribute()
         self._container = ContainerAttribute()
 
-        # TODO: Get rid of all this `expanded_inputs` nonsense
-        assert len(self.maker.expanded_inputs) == len(self.input_storage)
+        update_storage = [
+            container
+            for inp, container in zip(
+                self.maker.expanded_inputs, input_storage, strict=True
+            )
+            if inp.update is not None
+        ]
+        # Updates are the last inner outputs that are not returned by Function.__call__
+        self.n_returned_outputs = len(self.output_storage) - len(update_storage)
+
+        # Function.__call__ is responsible for updating the inputs, unless the vm promises to do it itself
+        self.update_input_storage: tuple[tuple[int, Container], ...] = ()
+        if getattr(vm, "need_update_inputs", True):
+            self.update_input_storage = tuple(
+                zip(
+                    range(self.n_returned_outputs, len(output_storage)),
+                    update_storage,
+                    strict=True,
+                )
+            )
 
-        # This is used only when `vm.need_update_inputs` is `False`, because
-        # we're using one of the VM objects and it is putting updates back into
-        # the input containers all by itself.
-        self.n_returned_outputs = len(self.output_storage) - sum(
-            inp.update is not None for inp in self.maker.expanded_inputs
-        )
+        # In every function call we place inputs in the input_storage, and the vm places outputs in the output_storage.
+        # After the call, we want to erase (some of) these references, to allow Python to GC them if unused.
+        # Required input containers are the non-default inputs; they must always be provided again, so we can GC them.
+        self.clear_input_output_storage_data = [
+            container.storage for container in input_storage if container.required
+        ]
+        if getattr(vm, "allow_gc", False):
+            # If the vm allows us to GC the outputs, do it
+            self.clear_input_output_storage_data += [
+                container.storage
+                for container, variable in zip(
+                    self.output_storage, self.maker.fgraph.outputs, strict=True
+                )
+                if variable.owner is not None  # Not a constant output
+            ]
 
         for node in self.maker.fgraph.apply_nodes:
             if isinstance(node.op, HasInnerGraph):
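
This hunk moves work from call time to construction time: the pairs of (output index, destination container) for inputs with updates, and the list of storage cells to clear, are computed once in `__init__` instead of being rediscovered on every call. A minimal standalone sketch of the pairing idea, using a simplified stand-in `Container` rather than the real pytensor class:

    class Container:
        """Stand-in for pytensor's Container: a shared one-element storage cell."""

        def __init__(self):
            self.storage = [None]

        @property
        def data(self):
            return self.storage[0]

        @data.setter
        def data(self, value):
            self.storage[0] = value

    # Say the vm returns 5 outputs, the last 2 of which are updates for these inputs:
    update_containers = [Container(), Container()]
    n_outputs = 5
    n_returned_outputs = n_outputs - len(update_containers)

    # Pair each trailing update output with its destination container, once:
    update_input_storage = tuple(
        zip(range(n_returned_outputs, n_outputs), update_containers, strict=True)
    )

    # Per call, no scanning of all inputs is needed, just the precomputed pairs:
    outputs = [10, 20, 30, 40, 50]  # what the vm produced
    for i, container in update_input_storage:
        container.data = outputs[i]
    outputs = outputs[:n_returned_outputs]
    assert outputs == [10, 20, 30]
    assert [c.data for c in update_containers] == [40, 50]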
@@ -747,7 +774,7 @@ def checkSV(sv_ori, sv_rpl):
         elif isinstance(profile, str):
             profile = pytensor.compile.profiling.ProfileStats(message=profile)
 
-        f_cpy = maker.__class__(
+        f_cpy = type(maker)(
             inputs=ins,
             outputs=outs,
             fgraph=fg_cpy,
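
Note that `type(maker)` and `maker.__class__` name the same class here; `type()` is simply the more idiomatic spelling. A quick sanity check, with a hypothetical stand-in class:

    class FunctionMaker:  # hypothetical stand-in, not the real pytensor class
        pass

    maker = FunctionMaker()
    assert type(maker) is maker.__class__ is FunctionMaker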
@@ -765,6 +792,8 @@ def checkSV(sv_ori, sv_rpl):
             # check that.
             accept_inplace=True,
             no_fgraph_prep=True,
+            output_keys=maker.output_keys,
+            name=name,
         ).create(input_storage, storage_map=new_storage_map)
 
         for in_ori, in_cpy, ori, cpy in zip(
@@ -796,9 +825,6 @@ def checkSV(sv_ori, sv_rpl):
             in_cpy.variable = swap[in_ori.variable]
 
         f_cpy.trust_input = self.trust_input
-        f_cpy.unpack_single = self.unpack_single
-        f_cpy.name = name
-        f_cpy.maker.fgraph.name = name
         return f_cpy
 
     def _restore_defaults(self):
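
Together with the previous hunk, this changes how the copy is configured: `output_keys` and `name` now flow through the maker constructor, so the post-hoc attribute patches removed in the hunk above (`unpack_single`, `name`, `maker.fgraph.name`) become unnecessary. A schematic of the pattern with hypothetical names, assuming the constructor derives the dependent state itself:

    class Maker:  # hypothetical sketch of "configure via constructor, not by patching"
        def __init__(self, name=None, output_keys=None):
            self.name = name
            self.output_keys = output_keys
            # Derived state is computed in one place...
            self.unpack_single = output_keys is None

    original = Maker(name="f", output_keys=["a", "b"])
    # ...so a copy built through the constructor needs no fix-ups afterwards:
    copy = type(original)(name="f_copy", output_keys=original.output_keys)
    assert copy.unpack_single == original.unpack_single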
@@ -808,7 +834,7 @@ def _restore_defaults(self):
                 value = value.storage[0]
             self[i] = value
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, output_subset=None, **kwargs):
         """
         Evaluates value of a function on given arguments.
 
@@ -842,7 +868,6 @@ def __call__(self, *args, **kwargs):
         if profile:
             t0 = time.perf_counter()
 
-        output_subset = kwargs.pop("output_subset", None)
         if output_subset is not None:
             warnings.warn("output_subset is deprecated.", FutureWarning)
             if self.output_keys is not None:
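
Promoting `output_subset` from a `kwargs.pop` to an explicit keyword parameter makes the signature self-documenting and removes a dict pop from the hot path. A rough, machine-dependent micro-benchmark of the difference (the function names are illustrative):

    import timeit

    def call_pop(*args, **kwargs):
        return kwargs.pop("output_subset", None)

    def call_named(*args, output_subset=None, **kwargs):
        return output_subset

    # call_named avoids the per-call method lookup and dict pop
    print(timeit.timeit(lambda: call_pop(1, 2), number=100_000))
    print(timeit.timeit(lambda: call_named(1, 2), number=100_000))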
@@ -993,37 +1018,18 @@ def __call__(self, *args, **kwargs):
         if outputs is None:
             outputs = [x.data for x in self.output_storage]
 
-        # Remove internal references to required inputs.
-        # These cannot be re-used anyway.
-        for arg_container in input_storage:
-            if arg_container.required:
-                arg_container.storage[0] = None
-
-        # if we are allowing garbage collection, remove the
-        # output reference from the internal storage cells
-        if getattr(self.vm, "allow_gc", False):
-            # strict=False because we are in a hot loop
-            for o_container, o_variable in zip(
-                self.output_storage, self.maker.fgraph.outputs, strict=False
-            ):
-                if o_variable.owner is not None:
-                    # this node is the variable of computation
-                    # WARNING: This circumvents the 'readonly' attribute in x
-                    o_container.storage[0] = None
-
-        if getattr(self.vm, "need_update_inputs", True):
-            # Update the inputs that have an update function
-            # strict=False because we are in a hot loop
-            for input, storage in reversed(
-                list(zip(self.maker.expanded_inputs, input_storage, strict=False))
-            ):
-                if input.update is not None:
-                    storage.data = outputs.pop()
-        else:
-            outputs = outputs[: self.n_returned_outputs]
+        # Set updates and filter them out from the returned outputs
+        for i, input_storage in self.update_input_storage:
+            input_storage.data = outputs[i]
+        outputs = outputs[: self.n_returned_outputs]
+
+        # Remove input and output values from storage data
+        for storage_data in self.clear_input_output_storage_data:
+            storage_data[0] = None
 
         # Put default values back in the storage
-        self._restore_defaults()
+        if self.defaults:
+            self._restore_defaults()
 
         if profile:
             dt_call = time.perf_counter() - t0
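
With the pairs and the clear-list precomputed in `__init__`, the per-call epilogue collapses to two tight loops over plain lists: no `getattr` probes, no `reversed(list(zip(...)))`, no per-input `update is not None` tests. The clearing loop works because each `storage` cell is a one-element list shared with the vm; a sketch of why dropping that reference matters, with stand-in objects rather than pytensor containers:

    # Each storage cell is a shared one-element list; the function keeps a flat
    # list of the cells it is allowed to clear, built once at construction time.
    cells = [["big intermediate array"], ["another one"]]
    clear_input_output_storage_data = list(cells)  # precomputed in __init__

    # Per call, after the results have been read out:
    for storage_data in clear_input_output_storage_data:
        storage_data[0] = None  # drop the only internal reference

    # The vm sees the same, now-emptied cells, so Python can reclaim the arrays.
    assert all(cell[0] is None for cell in cells)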
@@ -1039,25 +1045,21 @@ def __call__(self, *args, **kwargs):
 
         if self.return_none:
             return None
-        elif self.unpack_single and len(outputs) == 1 and output_subset is None:
-            return outputs[0]
-        else:
-            if self.output_keys is not None:
-                assert len(self.output_keys) == len(outputs)
 
-                if output_subset is None:
-                    # strict=False because we are in a hot loop
-                    return dict(zip(self.output_keys, outputs, strict=False))
-                else:
-                    return {
-                        self.output_keys[index]: outputs[index]
-                        for index in output_subset
-                    }
+        if output_subset is not None:
+            outputs = [outputs[i] for i in output_subset]
 
-            if output_subset is None:
-                return outputs
+        if self.output_keys is None:
+            if self.unpack_single:
+                [out] = outputs
+                return out
             else:
-                return [outputs[i] for i in output_subset]
+                return outputs
+        else:
+            output_keys = self.output_keys
+            if output_subset is not None:
+                output_keys = [output_keys[i] for i in output_subset]
+            return dict(zip(output_keys, outputs, strict=True))
 
     value = property(
         lambda self: self._value,
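
The rewritten tail handles every return shape in one pass: optional `output_subset` indexing first, then a single unpacked value, a plain list, or a dict keyed by `output_keys`. Note that `[out] = outputs` also asserts there is exactly one output, where the old `len(outputs) == 1` test silently fell back to returning the list. A standalone sketch of the dict branch with hypothetical values:

    outputs = ["a_val", "b_val", "c_val"]
    output_keys = ["a", "b", "c"]
    output_subset = [0, 2]

    # Subset the values and the keys with the same indices...
    outputs = [outputs[i] for i in output_subset]
    output_keys = [output_keys[i] for i in output_subset]
    # ...so zip(strict=True) can never see a length mismatch:
    assert dict(zip(output_keys, outputs, strict=True)) == {"a": "a_val", "c": "c_val"}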
@@ -1091,10 +1093,6 @@ def get_shared(self):
10911093 """
10921094 return [i .variable for i in self .maker .inputs if i .implicit ]
10931095
1094- def sync_shared (self ):
1095- # NOTE: sync was needed on old gpu backend
1096- pass
1097-
10981096 def dprint (self , ** kwargs ):
10991097 """Debug print itself
11001098