
Commit 862f158

Reduce overhead of Function call
1 parent 8593f34 commit 862f158

1 file changed: +61 -63 lines changed

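The diff below trims the Python bookkeeping done on every call: the constructor now precomputes which containers receive updates and which storage cells should be cleared, so __call__ only walks precomputed lists. A minimal timing sketch of the per-call overhead being targeted (assumes pytensor is installed; the graph and iteration count are arbitrary):

# Overhead-measurement sketch: a trivial graph whose runtime is dominated by
# Function.__call__ bookkeeping rather than by the computation itself.
import timeit

import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")
f = pytensor.function([x], x + 1)

print(timeit.timeit(lambda: f(1.0), number=100_000))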

pytensor/compile/function/types.py

Lines changed: 61 additions & 63 deletions
@@ -540,15 +540,42 @@ def __contains__(self, item):
         self._value = ValueAttribute()
         self._container = ContainerAttribute()
 
-        # TODO: Get rid of all this `expanded_inputs` nonsense
-        assert len(self.maker.expanded_inputs) == len(self.input_storage)
+        update_storage = [
+            container
+            for inp, container in zip(
+                self.maker.expanded_inputs, input_storage, strict=True
+            )
+            if inp.update is not None
+        ]
+        # Updates are the last inner outputs that are not returned by Function.__call__
+        self.n_returned_outputs = len(self.output_storage) - len(update_storage)
+
+        # Function.__call__ is responsible for updating the inputs, unless the vm promises to do it itself
+        self.update_input_storage: tuple[int, Container] = ()
+        if getattr(vm, "need_update_inputs", True):
+            self.update_input_storage = tuple(
+                zip(
+                    range(self.n_returned_outputs, len(output_storage)),
+                    update_storage,
+                    strict=True,
+                )
+            )
 
-        # This is used only when `vm.need_update_inputs` is `False`, because
-        # we're using one of the VM objects and it is putting updates back into
-        # the input containers all by itself.
-        self.n_returned_outputs = len(self.output_storage) - sum(
-            inp.update is not None for inp in self.maker.expanded_inputs
-        )
+        # In every function call we place inputs in the input_storage, and the vm places outputs in the output_storage
+        # After the call, we want to erase (some of) these references, to allow Python to GC them if unused
+        # Required input containers are the non-default inputs, must always be provided again, so we GC them
+        self.clear_input_output_storage_data = [
+            container.storage for container in input_storage if container.required
+        ]
+        if getattr(vm, "allow_gc", False):
+            # If the vm allows us to GC the outputs, do it
+            self.clear_input_output_storage_data += [
+                container.storage
+                for container, variable in zip(
+                    self.output_storage, self.maker.fgraph.outputs, strict=True
+                )
+                if variable.owner is not None  # Not a constant output
+            ]
 
         for node in self.maker.fgraph.apply_nodes:
             if isinstance(node.op, HasInnerGraph):
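The hunk above hoists everything that does not depend on the call arguments out of __call__ and into construction. A standalone sketch of that pattern with hypothetical names (plain Python lists stand in for PyTensor Container.storage cells; this is not PyTensor's API):

# Hypothetical, simplified illustration of the precomputation pattern:
# decide once, in __init__, which storage cells must be cleared after a call,
# so that __call__ only iterates over a short precomputed list.
class PrecomputedCaller:
    def __init__(self, input_containers, input_required):
        self.input_containers = input_containers
        # Computed once instead of re-derived on every call.
        self.clear_after_call = [
            c for c, required in zip(input_containers, input_required) if required
        ]

    def __call__(self, *values):
        for container, value in zip(self.input_containers, values):
            container[0] = value
        result = sum(c[0] for c in self.input_containers)  # stand-in "computation"
        for container in self.clear_after_call:
            container[0] = None  # drop references so Python can GC them
        return result

# Required inputs are cleared after the call; optional ones keep their value.
a, b = [None], [None]
caller = PrecomputedCaller([a, b], input_required=[True, False])
print(caller(1, 2))  # 3
print(a, b)          # [None] [2]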
@@ -747,7 +774,7 @@ def checkSV(sv_ori, sv_rpl):
         elif isinstance(profile, str):
             profile = pytensor.compile.profiling.ProfileStats(message=profile)
 
-        f_cpy = maker.__class__(
+        f_cpy = type(maker)(
             inputs=ins,
             outputs=outs,
             fgraph=fg_cpy,
@@ -765,6 +792,8 @@ def checkSV(sv_ori, sv_rpl):
             # check that.
             accept_inplace=True,
             no_fgraph_prep=True,
+            output_keys=maker.output_keys,
+            name=name,
         ).create(input_storage, storage_map=new_storage_map)
 
         for in_ori, in_cpy, ori, cpy in zip(
@@ -796,9 +825,6 @@ def checkSV(sv_ori, sv_rpl):
             in_cpy.variable = swap[in_ori.variable]
 
         f_cpy.trust_input = self.trust_input
-        f_cpy.unpack_single = self.unpack_single
-        f_cpy.name = name
-        f_cpy.maker.fgraph.name = name
         return f_cpy
 
     def _restore_defaults(self):
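Because output_keys and name are now forwarded to the maker when copying, the copy no longer needs those attributes patched afterwards. A hedged usage sketch (assumes pytensor is installed; the name= keyword of Function.copy is taken from the surrounding code):

# Usage sketch: copying a compiled function and giving the copy its own name.
import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")
f = pytensor.function([x], 2 * x, name="double")
f_copy = f.copy(name="double_copy")
print(f_copy(3.0))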
@@ -808,7 +834,7 @@ def _restore_defaults(self):
                     value = value.storage[0]
                 self[i] = value
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, output_subset=None, **kwargs):
         """
         Evaluates value of a function on given arguments.
 
@@ -842,7 +868,6 @@ def __call__(self, *args, **kwargs):
         if profile:
             t0 = time.perf_counter()
 
-        output_subset = kwargs.pop("output_subset", None)
         if output_subset is not None:
             warnings.warn("output_subset is deprecated.", FutureWarning)
             if self.output_keys is not None:
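Making output_subset an explicit keyword parameter avoids a dict pop from **kwargs on every call. A pure-Python sketch (hypothetical function names) of the difference:

import timeit

def call_with_pop(*args, **kwargs):
    output_subset = kwargs.pop("output_subset", None)  # old pattern
    return output_subset

def call_with_keyword(*args, output_subset=None, **kwargs):
    return output_subset  # new pattern: bound directly by the interpreter

print(timeit.timeit(lambda: call_with_pop(1.0), number=1_000_000))
print(timeit.timeit(lambda: call_with_keyword(1.0), number=1_000_000))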
@@ -993,37 +1018,18 @@ def __call__(self, *args, **kwargs):
         if outputs is None:
             outputs = [x.data for x in self.output_storage]
 
-        # Remove internal references to required inputs.
-        # These cannot be re-used anyway.
-        for arg_container in input_storage:
-            if arg_container.required:
-                arg_container.storage[0] = None
-
-        # if we are allowing garbage collection, remove the
-        # output reference from the internal storage cells
-        if getattr(self.vm, "allow_gc", False):
-            # strict=False because we are in a hot loop
-            for o_container, o_variable in zip(
-                self.output_storage, self.maker.fgraph.outputs, strict=False
-            ):
-                if o_variable.owner is not None:
-                    # this node is the variable of computation
-                    # WARNING: This circumvents the 'readonly' attribute in x
-                    o_container.storage[0] = None
-
-        if getattr(self.vm, "need_update_inputs", True):
-            # Update the inputs that have an update function
-            # strict=False because we are in a hot loop
-            for input, storage in reversed(
-                list(zip(self.maker.expanded_inputs, input_storage, strict=False))
-            ):
-                if input.update is not None:
-                    storage.data = outputs.pop()
-        else:
-            outputs = outputs[: self.n_returned_outputs]
+        # Set updates and filter them out from the returned outputs
+        for i, input_storage in self.update_input_storage:
+            input_storage.data = outputs[i]
+        outputs = outputs[: self.n_returned_outputs]
+
+        # Remove input and output values from storage data
+        for storage_data in self.clear_input_output_storage_data:
+            storage_data[0] = None
 
         # Put default values back in the storage
-        self._restore_defaults()
+        if self.defaults:
+            self._restore_defaults()
 
         if profile:
             dt_call = time.perf_counter() - t0
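As the new comment says, update outputs are the trailing inner outputs: they are written back into their input containers and stripped from what __call__ returns. A usage sketch of that behaviour from the caller's side (assumes pytensor is installed):

import pytensor
import pytensor.tensor as pt

counter = pytensor.shared(0, name="counter")
x = pt.scalar("x")
f = pytensor.function([x], 2 * x, updates=[(counter, counter + 1)])

print(f(3.0))               # only the declared output is returned
print(counter.get_value())  # 1: the update was applied in place, not returned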
@@ -1039,25 +1045,21 @@ def __call__(self, *args, **kwargs):
 
         if self.return_none:
             return None
-        elif self.unpack_single and len(outputs) == 1 and output_subset is None:
-            return outputs[0]
-        else:
-            if self.output_keys is not None:
-                assert len(self.output_keys) == len(outputs)
 
-                if output_subset is None:
-                    # strict=False because we are in a hot loop
-                    return dict(zip(self.output_keys, outputs, strict=False))
-                else:
-                    return {
-                        self.output_keys[index]: outputs[index]
-                        for index in output_subset
-                    }
+        if output_subset is not None:
+            outputs = [outputs[i] for i in output_subset]
 
-            if output_subset is None:
-                return outputs
+        if self.output_keys is None:
+            if self.unpack_single:
+                [out] = outputs
+                return out
             else:
-                return [outputs[i] for i in output_subset]
+                return outputs
+        else:
+            output_keys = self.output_keys
+            if output_subset is not None:
+                output_keys = [output_keys[i] for i in output_subset]
+            return dict(zip(output_keys, outputs, strict=True))
 
     value = property(
         lambda self: self._value,
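The rewritten return path distinguishes three shapes: a single unpacked value, a list of outputs, and a dict keyed by output_keys. A hedged usage sketch (assumes pytensor is installed and that dict-valued outputs, which is where output_keys comes from, are accepted as in Theano):

import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")

f_single = pytensor.function([x], 2 * x)                             # unpack_single path
f_list = pytensor.function([x], [2 * x, x**2])                       # plain list path
f_named = pytensor.function([x], {"double": 2 * x, "square": x**2})  # output_keys path

print(f_single(3.0))  # a single value, not a one-element list
print(f_list(3.0))    # a list of two values
print(f_named(3.0))   # a dict keyed by "double" and "square"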
@@ -1091,10 +1093,6 @@ def get_shared(self):
         """
         return [i.variable for i in self.maker.inputs if i.implicit]
 
-    def sync_shared(self):
-        # NOTE: sync was needed on old gpu backend
-        pass
-
     def dprint(self, **kwargs):
         """Debug print itself
 