diff --git a/pytensor/compile/debugmode.py b/pytensor/compile/debugmode.py index 5c51222a1b..384f9eb874 100644 --- a/pytensor/compile/debugmode.py +++ b/pytensor/compile/debugmode.py @@ -1966,6 +1966,12 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions If the outputs argument for pytensor.function was a list, then output_keys is None. If the outputs argument was a dict, then output_keys is a sorted list of the keys from that dict. + trust_input : bool, default False + If True, no input validation checks are performed when the function is + called. This includes checking the number of inputs, their types and + that multiple inputs are not aliased to each other. Failure to meet any + of these conditions can lead to computational errors or to the + interpreter crashing. Notes ----- @@ -1993,6 +1999,7 @@ def __init__( output_keys=None, name=None, no_fgraph_prep=False, + trust_input=False, ): self.mode = mode self.profile = profile @@ -2146,6 +2153,7 @@ def __init__( self.on_unused_input = on_unused_input # Used for the pickling/copy self.output_keys = output_keys self.name = name + self.trust_input = trust_input self.required = [(i.value is None) for i in self.inputs] self.refeed = [ diff --git a/pytensor/compile/function/__init__.py b/pytensor/compile/function/__init__.py index 7fa3a179ac..61e4aa3cfe 100644 --- a/pytensor/compile/function/__init__.py +++ b/pytensor/compile/function/__init__.py @@ -37,6 +37,7 @@ def function_dump( profile: bool | ProfileStats | None = None, on_unused_input: str | None = None, extra_tag_to_remove: str | None = None, + trust_input: bool = False, ): """ This is helpful to make a reproducible case for problems during PyTensor @@ -82,6 +83,7 @@ def function_dump( "allow_input_downcast": allow_input_downcast, "profile": profile, "on_unused_input": on_unused_input, + "trust_input": trust_input, } with Path(filename).open("wb") as f: pickler = pytensor.misc.pkl_utils.StripPickler( @@ -107,6 +109,7 @@ def function( allow_input_downcast: bool | None = None, profile: bool | ProfileStats | None = None, on_unused_input: str | None = None, + trust_input: bool = False, ): """ Return a :class:`callable object ` @@ -164,6 +167,12 @@ def function( on_unused_input What to do if a variable in the 'inputs' list is not used in the graph. Possible values are 'raise', 'warn', 'ignore' and None. + trust_input: bool, default False + If True, no input validation checks are performed when the function is + called. This includes checking the number of inputs, their types and + that multiple inputs are not aliased to each other. Failure to meet any + of these conditions can lead to computational errors or to the + interpreter crashing. Returns ------- @@ -310,7 +319,12 @@ def opt_log1p(node): "semantics, which disallow using updates and givens" ) fn = orig_function( - inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name + inputs, + outputs, + mode=mode, + accept_inplace=accept_inplace, + name=name, + trust_input=trust_input, ) else: # note: pfunc will also call orig_function -- orig_function is @@ -329,5 +343,6 @@ def opt_log1p(node): on_unused_input=on_unused_input, profile=profile, output_keys=output_keys, + trust_input=trust_input, ) return fn diff --git a/pytensor/compile/function/pfunc.py b/pytensor/compile/function/pfunc.py index b938cb6a55..749ec5cb42 100644 --- a/pytensor/compile/function/pfunc.py +++ b/pytensor/compile/function/pfunc.py @@ -377,6 +377,7 @@ def pfunc( on_unused_input=None, output_keys=None, fgraph: FunctionGraph | None = None, + trust_input: bool = False, ) -> Function: """ Function-constructor for graphs with shared variables. @@ -425,6 +426,12 @@ def pfunc( fgraph An existing `FunctionGraph` from which to construct the returned `Function`. When this is non-``None``, nothing is cloned. + trust_input : bool, default False + If True, no input validation checks are performed when the function is + called. This includes checking the number of inputs, their types and + that multiple inputs are not aliased to each other. Failure to meet any + of these conditions can lead to computational errors or to the + interpreter crashing. Returns ------- @@ -472,6 +479,7 @@ def pfunc( on_unused_input=on_unused_input, output_keys=output_keys, fgraph=fgraph, + trust_input=trust_input, ) diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index 0ccaa9e00b..9cc85f3d24 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -373,6 +373,7 @@ def __init__( return_none: bool, output_keys, maker: "FunctionMaker", + trust_input: bool = False, name: str | None = None, ): """ @@ -407,6 +408,12 @@ def __init__( TODO maker The `FunctionMaker` that created this instance. + trust_input : bool, default False + If True, no input validation checks are performed when the function is + called. This includes checking the number of inputs, their types and + that multiple inputs are not aliased to each other. Failure to meet any + of these conditions can lead to computational errors or to the + interpreter crashing. name A string name. """ @@ -420,7 +427,7 @@ def __init__( self.return_none = return_none self.maker = maker self.profile = None # reassigned in FunctionMaker.create - self.trust_input = False # If True, we don't check the input parameter + self.trust_input = trust_input # If True, we don't check the input parameter self.name = name self.nodes_with_inner_function = [] self.output_keys = output_keys @@ -1341,7 +1348,12 @@ class FunctionMaker: name : str An optional name for this function. If used, the profile mode will print the time spent in this function. - + trust_input : bool, default False + If True, no input validation checks are performed when the function is + called. This includes checking the number of inputs, their types and + that multiple inputs are not aliased to each other. Failure to meet any + of these conditions can lead to computational errors or to the + interpreter crashing. """ @staticmethod @@ -1507,6 +1519,7 @@ def __init__( output_keys=None, name=None, no_fgraph_prep=False, + trust_input=False, ): # Save the provided mode, not the instantiated mode. # The instantiated mode don't pickle and if we unpickle an PyTensor @@ -1609,6 +1622,7 @@ def __init__( self.on_unused_input = on_unused_input # Used for the pickling/copy self.output_keys = output_keys self.name = name + self.trust_input = trust_input self.required = [(i.value is None) for i in self.inputs] self.refeed = [ @@ -1726,6 +1740,7 @@ def create(self, input_storage=None, storage_map=None): self.return_none, self.output_keys, self, + trust_input=self.trust_input, name=self.name, ) @@ -1743,6 +1758,7 @@ def orig_function( on_unused_input=None, output_keys=None, fgraph: FunctionGraph | None = None, + trust_input: bool = False, ) -> Function: """ Return a Function that will calculate the outputs from the inputs. @@ -1773,7 +1789,12 @@ def orig_function( fgraph An existing `FunctionGraph` to use instead of constructing a new one from cloned `outputs`. - + trust_input : bool, default False + If True, no input validation checks are performed when the function is + called. This includes checking the number of inputs, their types and + that multiple inputs are not aliased to each other. Failure to meet any + of these conditions can lead to computational errors or to the + interpreter crashing. """ if profile: @@ -1806,6 +1827,7 @@ def orig_function( output_keys=output_keys, name=name, fgraph=fgraph, + trust_input=trust_input, ) with config.change_flags(compute_test_value="off"): fn = m.create(defaults) diff --git a/tests/compile/function/test_function.py b/tests/compile/function/test_function.py index 9f75ef15d8..d1f94dd689 100644 --- a/tests/compile/function/test_function.py +++ b/tests/compile/function/test_function.py @@ -54,6 +54,16 @@ def test_function_name(): assert regex.match(func.name) is not None +def test_trust_input(): + x = dvector() + y = shared(1) + z = x + y + f = function([x], z) + assert f.trust_input is False + f = function([x], z, trust_input=True) + assert f.trust_input is True + + class TestFunctionIn: def test_in_strict(self): a = dvector()