Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pytensor/compile/debugmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1966,6 +1966,12 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
If the outputs argument for pytensor.function was a list, then
output_keys is None. If the outputs argument was a dict, then
output_keys is a sorted list of the keys from that dict.
trust_input : bool, default False
If True, no input validation checks are performed when the function is
called. This includes checking the number of inputs, their types and
that multiple inputs are not aliased to each other. Failure to meet any
of these conditions can lead to computational errors or to the
interpreter crashing.

Notes
-----
Expand Down Expand Up @@ -1993,6 +1999,7 @@ def __init__(
output_keys=None,
name=None,
no_fgraph_prep=False,
trust_input=False,
):
self.mode = mode
self.profile = profile
Expand Down Expand Up @@ -2146,6 +2153,7 @@ def __init__(
self.on_unused_input = on_unused_input # Used for the pickling/copy
self.output_keys = output_keys
self.name = name
self.trust_input = trust_input

self.required = [(i.value is None) for i in self.inputs]
self.refeed = [
Expand Down
17 changes: 16 additions & 1 deletion pytensor/compile/function/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def function_dump(
profile: bool | ProfileStats | None = None,
on_unused_input: str | None = None,
extra_tag_to_remove: str | None = None,
trust_input: bool = False,
):
"""
This is helpful to make a reproducible case for problems during PyTensor
Expand Down Expand Up @@ -82,6 +83,7 @@ def function_dump(
"allow_input_downcast": allow_input_downcast,
"profile": profile,
"on_unused_input": on_unused_input,
"trust_input": trust_input,
}
with Path(filename).open("wb") as f:
pickler = pytensor.misc.pkl_utils.StripPickler(
Expand All @@ -107,6 +109,7 @@ def function(
allow_input_downcast: bool | None = None,
profile: bool | ProfileStats | None = None,
on_unused_input: str | None = None,
trust_input: bool = False,
):
"""
Return a :class:`callable object <pytensor.compile.function.types.Function>`
Expand Down Expand Up @@ -164,6 +167,12 @@ def function(
on_unused_input
What to do if a variable in the 'inputs' list is not used in the graph.
Possible values are 'raise', 'warn', 'ignore' and None.
trust_input: bool, default False
If True, no input validation checks are performed when the function is
called. This includes checking the number of inputs, their types and
that multiple inputs are not aliased to each other. Failure to meet any
of these conditions can lead to computational errors or to the
interpreter crashing.

Returns
-------
Expand Down Expand Up @@ -310,7 +319,12 @@ def opt_log1p(node):
"semantics, which disallow using updates and givens"
)
fn = orig_function(
inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
inputs,
outputs,
mode=mode,
accept_inplace=accept_inplace,
name=name,
trust_input=trust_input,
)
else:
# note: pfunc will also call orig_function -- orig_function is
Expand All @@ -329,5 +343,6 @@ def opt_log1p(node):
on_unused_input=on_unused_input,
profile=profile,
output_keys=output_keys,
trust_input=trust_input,
)
return fn
8 changes: 8 additions & 0 deletions pytensor/compile/function/pfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ def pfunc(
on_unused_input=None,
output_keys=None,
fgraph: FunctionGraph | None = None,
trust_input: bool = False,
) -> Function:
"""
Function-constructor for graphs with shared variables.
Expand Down Expand Up @@ -425,6 +426,12 @@ def pfunc(
fgraph
An existing `FunctionGraph` from which to construct the returned
`Function`. When this is non-``None``, nothing is cloned.
trust_input : bool, default False
If True, no input validation checks are performed when the function is
called. This includes checking the number of inputs, their types and
that multiple inputs are not aliased to each other. Failure to meet any
of these conditions can lead to computational errors or to the
interpreter crashing.

Returns
-------
Expand Down Expand Up @@ -472,6 +479,7 @@ def pfunc(
on_unused_input=on_unused_input,
output_keys=output_keys,
fgraph=fgraph,
trust_input=trust_input,
)


Expand Down
28 changes: 25 additions & 3 deletions pytensor/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def __init__(
return_none: bool,
output_keys,
maker: "FunctionMaker",
trust_input: bool = False,
name: str | None = None,
):
"""
Expand Down Expand Up @@ -407,6 +408,12 @@ def __init__(
TODO
maker
The `FunctionMaker` that created this instance.
trust_input : bool, default False
If True, no input validation checks are performed when the function is
called. This includes checking the number of inputs, their types and
that multiple inputs are not aliased to each other. Failure to meet any
of these conditions can lead to computational errors or to the
interpreter crashing.
name
A string name.
"""
Expand All @@ -420,7 +427,7 @@ def __init__(
self.return_none = return_none
self.maker = maker
self.profile = None # reassigned in FunctionMaker.create
self.trust_input = False # If True, we don't check the input parameter
self.trust_input = trust_input # If True, we don't check the input parameter
self.name = name
self.nodes_with_inner_function = []
self.output_keys = output_keys
Expand Down Expand Up @@ -1341,7 +1348,12 @@ class FunctionMaker:
name : str
An optional name for this function. If used, the profile mode will
print the time spent in this function.

trust_input : bool, default False
If True, no input validation checks are performed when the function is
called. This includes checking the number of inputs, their types and
that multiple inputs are not aliased to each other. Failure to meet any
of these conditions can lead to computational errors or to the
interpreter crashing.
"""

@staticmethod
Expand Down Expand Up @@ -1507,6 +1519,7 @@ def __init__(
output_keys=None,
name=None,
no_fgraph_prep=False,
trust_input=False,
):
# Save the provided mode, not the instantiated mode.
# The instantiated mode don't pickle and if we unpickle an PyTensor
Expand Down Expand Up @@ -1609,6 +1622,7 @@ def __init__(
self.on_unused_input = on_unused_input # Used for the pickling/copy
self.output_keys = output_keys
self.name = name
self.trust_input = trust_input

self.required = [(i.value is None) for i in self.inputs]
self.refeed = [
Expand Down Expand Up @@ -1726,6 +1740,7 @@ def create(self, input_storage=None, storage_map=None):
self.return_none,
self.output_keys,
self,
trust_input=self.trust_input,
name=self.name,
)

Expand All @@ -1743,6 +1758,7 @@ def orig_function(
on_unused_input=None,
output_keys=None,
fgraph: FunctionGraph | None = None,
trust_input: bool = False,
) -> Function:
"""
Return a Function that will calculate the outputs from the inputs.
Expand Down Expand Up @@ -1773,7 +1789,12 @@ def orig_function(
fgraph
An existing `FunctionGraph` to use instead of constructing a new one
from cloned `outputs`.

trust_input : bool, default False
If True, no input validation checks are performed when the function is
called. This includes checking the number of inputs, their types and
that multiple inputs are not aliased to each other. Failure to meet any
of these conditions can lead to computational errors or to the
interpreter crashing.
"""

if profile:
Expand Down Expand Up @@ -1806,6 +1827,7 @@ def orig_function(
output_keys=output_keys,
name=name,
fgraph=fgraph,
trust_input=trust_input,
)
with config.change_flags(compute_test_value="off"):
fn = m.create(defaults)
Expand Down
10 changes: 10 additions & 0 deletions tests/compile/function/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ def test_function_name():
assert regex.match(func.name) is not None


def test_trust_input():
x = dvector()
y = shared(1)
z = x + y
f = function([x], z)
assert f.trust_input is False
f = function([x], z, trust_input=True)
assert f.trust_input is True


class TestFunctionIn:
def test_in_strict(self):
a = dvector()
Expand Down