Skip to content

Commit f0c903f

Browse files
committed
Deprecate output_subset and cleanup Function.__call__
Also do not profile calls by default
1 parent fa98568 commit f0c903f

File tree

1 file changed

+66
-75
lines changed

1 file changed

+66
-75
lines changed

pytensor/compile/function/types.py

Lines changed: 66 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,8 @@ class Function:
326326
def __init__(
327327
self,
328328
vm: "VM",
329-
input_storage,
330-
output_storage,
329+
input_storage: list[Container],
330+
output_storage: list[Container],
331331
indices,
332332
outputs,
333333
defaults,
@@ -388,6 +388,11 @@ def __init__(
388388
self.nodes_with_inner_function = []
389389
self.output_keys = output_keys
390390

391+
if self.output_keys is not None:
392+
warnings.warn(FutureWarning, "output_keys is deprecated.")
393+
394+
assert len(self.output_storage) == len(self.maker.fgraph.outputs)
395+
391396
# See if we have any mutable / borrow inputs
392397
# TODO: this only need to be set if there is more than one input
393398
self._check_for_aliased_inputs = False
@@ -408,11 +413,6 @@ def __init__(
408413
finder = {}
409414
inv_finder = {}
410415

411-
def distribute(indices, cs, value):
412-
input.distribute(value, indices, cs)
413-
for c in cs:
414-
c.provided += 1
415-
416416
# Store the list of names of named inputs.
417417
named_inputs = []
418418
# Count the number of un-named inputs.
@@ -777,6 +777,13 @@ def checkSV(sv_ori, sv_rpl):
777777
f_cpy.maker.fgraph.name = name
778778
return f_cpy
779779

780+
def _restore_defaults(self):
781+
for i, (required, refeed, value) in enumerate(self.defaults):
782+
if refeed:
783+
if isinstance(value, Container):
784+
value = value.storage[0]
785+
self[i] = value
786+
780787
def __call__(self, *args, **kwargs):
781788
"""
782789
Evaluates value of a function on given arguments.
@@ -806,51 +813,45 @@ def __call__(self, *args, **kwargs):
806813
if ``output_subset`` is not passed.
807814
"""
808815

809-
def restore_defaults():
810-
for i, (required, refeed, value) in enumerate(self.defaults):
811-
if refeed:
812-
if isinstance(value, Container):
813-
value = value.storage[0]
814-
self[i] = value
815-
816816
profile = self.profile
817-
t0 = time.perf_counter()
817+
if profile is not None:
818+
t0 = time.perf_counter()
818819

819820
output_subset = kwargs.pop("output_subset", None)
820-
if output_subset is not None and self.output_keys is not None:
821-
output_subset = [self.output_keys.index(key) for key in output_subset]
821+
if output_subset is not None:
822+
warnings.warn(FutureWarning, "output_subset is deprecated.")
823+
if self.output_keys is not None:
824+
output_subset = [self.output_keys.index(key) for key in output_subset]
822825

823826
# Reinitialize each container's 'provided' counter
824827
if self.trust_input:
825-
i = 0
826-
for arg in args:
827-
s = self.input_storage[i]
828-
s.storage[0] = arg
829-
i += 1
828+
for arg_container, arg in zip(self.input_storage, args, strict=False):
829+
arg_container.storage[0] = arg
830830
else:
831-
for c in self.input_storage:
832-
c.provided = 0
831+
for arg_container in self.input_storage:
832+
arg_container.provided = 0
833833

834834
if len(args) + len(kwargs) > len(self.input_storage):
835835
raise TypeError("Too many parameter passed to pytensor function")
836836

837837
# Set positional arguments
838-
i = 0
839-
for arg in args:
838+
for arg_container, arg in zip(self.input_storage, args, strict=False):
840839
# TODO: provide a option for skipping the filter if we really
841840
# want speed.
842-
s = self.input_storage[i]
843841
# see this emails for a discuation about None as input
844842
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
845843
if arg is None:
846-
s.storage[0] = arg
844+
arg_container.storage[0] = arg
847845
else:
848846
try:
849-
s.storage[0] = s.type.filter(
850-
arg, strict=s.strict, allow_downcast=s.allow_downcast
847+
arg_container.storage[0] = arg_container.type.filter(
848+
arg,
849+
strict=arg_container.strict,
850+
allow_downcast=arg_container.allow_downcast,
851851
)
852852

853853
except Exception as e:
854+
i = self.input_storage.index(arg_container)
854855
function_name = "pytensor function"
855856
argument_name = "argument"
856857
if self.name:
@@ -875,27 +876,23 @@ def restore_defaults():
875876
+ function_name
876877
+ f" at index {int(i)} (0-based). {where}"
877878
) + e.args
878-
restore_defaults()
879+
self._restore_defaults()
879880
raise
880-
s.provided += 1
881-
i += 1
881+
arg_container.provided += 1
882882

883883
# Set keyword arguments
884884
if kwargs: # for speed, skip the items for empty kwargs
885885
for k, arg in kwargs.items():
886886
self[k] = arg
887887

888-
if (
889-
not self.trust_input
890-
and
891-
# The getattr is only needed for old pickle
892-
getattr(self, "_check_for_aliased_inputs", True)
893-
):
888+
if not self.trust_input:
894889
# Collect aliased inputs among the storage space
895890
args_share_memory = []
896-
for i in range(len(self.input_storage)):
897-
i_var = self.maker.inputs[i].variable
898-
i_val = self.input_storage[i].storage[0]
891+
for i, (inp, arg_storage) in enumerate(
892+
zip(self.maker.inputs, self.input_storage)
893+
):
894+
i_var = inp.variable
895+
i_val = arg_storage.storage[0]
899896
if hasattr(i_var.type, "may_share_memory"):
900897
is_aliased = False
901898
for j in range(len(args_share_memory)):
@@ -932,36 +929,36 @@ def restore_defaults():
932929
self.input_storage[j].storage[0]
933930
)
934931

935-
# Check if inputs are missing, or if inputs were set more than once, or
936-
# if we tried to provide inputs that are supposed to be implicit.
937-
if not self.trust_input:
938-
for c in self.input_storage:
939-
if c.required and not c.provided:
940-
restore_defaults()
932+
# Check if inputs are missing, or if inputs were set more than once, or
933+
# if we tried to provide inputs that are supposed to be implicit.
934+
for arg_container in self.input_storage:
935+
if arg_container.required and not arg_container.provided:
936+
self._restore_defaults()
941937
raise TypeError(
942-
f"Missing required input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
938+
f"Missing required input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
943939
)
944-
if c.provided > 1:
945-
restore_defaults()
940+
if arg_container.provided > 1:
941+
self._restore_defaults()
946942
raise TypeError(
947-
f"Multiple values for input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
943+
f"Multiple values for input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
948944
)
949-
if c.implicit and c.provided > 0:
950-
restore_defaults()
945+
if arg_container.implicit and arg_container.provided > 0:
946+
self._restore_defaults()
951947
raise TypeError(
952-
f"Tried to provide value for implicit input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
948+
f"Tried to provide value for implicit input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
953949
)
954950

955951
# Do the actual work
956-
t0_fn = time.perf_counter()
952+
if profile is not None:
953+
t0_fn = time.perf_counter()
957954
try:
958955
outputs = (
959956
self.vm()
960957
if output_subset is None
961958
else self.vm(output_subset=output_subset)
962959
)
963960
except Exception:
964-
restore_defaults()
961+
self._restore_defaults()
965962
if hasattr(self.vm, "position_of_error"):
966963
# this is a new vm-provided function or c linker
967964
# they need this because the exception manipulation
@@ -979,9 +976,9 @@ def restore_defaults():
979976
# old-style linkers raise their own exceptions
980977
raise
981978

982-
dt_fn = time.perf_counter() - t0_fn
983-
self.maker.mode.fn_time += dt_fn
984-
if profile:
979+
if profile is not None:
980+
dt_fn = time.perf_counter() - t0_fn
981+
self.maker.mode.fn_time += dt_fn
985982
profile.vm_call_time += dt_fn
986983

987984
# Retrieve the values that were computed
@@ -991,14 +988,13 @@ def restore_defaults():
991988

992989
# Remove internal references to required inputs.
993990
# These cannot be re-used anyway.
994-
for c in self.input_storage:
995-
if c.required:
996-
c.storage[0] = None
991+
for arg_container in self.input_storage:
992+
if arg_container.required:
993+
arg_container.storage[0] = None
997994

998995
# if we are allowing garbage collection, remove the
999996
# output reference from the internal storage cells
1000997
if getattr(self.vm, "allow_gc", False):
1001-
assert len(self.output_storage) == len(self.maker.fgraph.outputs)
1002998
for o_container, o_variable in zip(
1003999
self.output_storage, self.maker.fgraph.outputs
10041000
):
@@ -1020,17 +1016,12 @@ def restore_defaults():
10201016
outputs = outputs[: self.n_returned_outputs]
10211017

10221018
# Put default values back in the storage
1023-
restore_defaults()
1024-
#
1025-
# NOTE: This logic needs to be replicated in
1026-
# scan.
1027-
# grep for 'PROFILE_CODE'
1028-
#
1029-
1030-
dt_call = time.perf_counter() - t0
1031-
pytensor.compile.profiling.total_fct_exec_time += dt_call
1032-
self.maker.mode.call_time += dt_call
1033-
if profile:
1019+
self._restore_defaults()
1020+
1021+
if profile is not None:
1022+
dt_call = time.perf_counter() - t0
1023+
pytensor.compile.profiling.total_fct_exec_time += dt_call
1024+
self.maker.mode.call_time += dt_call
10341025
profile.fct_callcount += 1
10351026
profile.fct_call_time += dt_call
10361027
if hasattr(self.vm, "update_profile"):

0 commit comments

Comments
 (0)