diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py
index eae4cb4e8b..0ccaa9e00b 100644
--- a/pytensor/compile/function/types.py
+++ b/pytensor/compile/function/types.py
@@ -5,6 +5,7 @@
 import logging
 import time
 import warnings
+from collections.abc import Sequence
 from itertools import chain
 from typing import TYPE_CHECKING
@@ -168,6 +169,59 @@ def validate(self, fgraph):
                 raise InconsistencyError(f"Trying to destroy a protected variable: {r}")
 
 
+def add_supervisor_to_fgraph(
+    fgraph: FunctionGraph,
+    input_specs: Sequence[SymbolicInput],
+    accept_inplace: bool = False,
+) -> None:
+    """Set up the Supervisor feature in a FunctionGraph, so that inplace rewrites can be used.
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        The FunctionGraph to set up the Supervisor feature in.
+    input_specs: Sequence of SymbolicInput
+        The input specifications for the FunctionGraph.
+        Inputs with the attribute `mutable=False` which are not already destroyed by an inplace operation
+        (if `accept_inplace` is True) will be protected from inplace operations.
+        Otherwise, they will be allowed to be destroyed.
+    accept_inplace: bool
+        Whether to allow inplace operations to already be present in the graph.
+
+    Raises
+    ------
+    TypeError
+        If inplace operations are not allowed and the graph already contains inplace operations.
+
+    """
+
+    has_destroy_handler = hasattr(fgraph, "destroyers")
+    if not (has_destroy_handler and accept_inplace):
+        # Check if fgraph already contains destructive operations,
+        # in which case we need to add a DestroyHandler or raise an error
+        for node in fgraph.apply_nodes:
+            if node.op.destroy_map:
+                if not accept_inplace:
+                    raise TypeError(
+                        f"Graph must not contain inplace operations: {node}"
+                    )
+                else:
+                    has_destroy_handler = True
+                    fgraph.attach_feature(DestroyHandler())
+                    break
+
+    # Protect all immutable inputs from inplace operations.
+    fgraph.attach_feature(
+        Supervisor(
+            input
+            for spec, input in zip(input_specs, fgraph.inputs, strict=True)
+            if not (
+                spec.mutable or has_destroy_handler and fgraph.has_destroyers([input])
+            )
+        )
+    )
+
+
 def std_fgraph(
     input_specs: list[SymbolicInput],
     output_specs: list[SymbolicOutput],
@@ -229,24 +283,8 @@ def std_fgraph(
 
         found_updates.extend(map(SymbolicOutput, updates))
 
-    for node in fgraph.apply_nodes:
-        if node.op.destroy_map:
-            if not accept_inplace:
-                raise TypeError(f"Graph must not contain inplace operations: {node}")
-            else:
-                fgraph.attach_feature(DestroyHandler())
-                break
-
-    # We need to protect all immutable inputs from inplace operations.
-    fgraph.attach_feature(
-        Supervisor(
-            input
-            for spec, input in zip(input_specs, fgraph.inputs, strict=True)
-            if not (
-                spec.mutable
-                or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
-            )
-        )
+    add_supervisor_to_fgraph(
+        fgraph=fgraph, input_specs=input_specs, accept_inplace=accept_inplace
     )
 
     # If named nodes are replaced, keep the name
diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index 43a5e131cb..82931bced6 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -138,7 +138,11 @@ def apply(self, fgraph):
                     break
         if not supervisor_added:
             warnings.warn(
-                f"A Supervisor feature is missing from {fgraph}.",
+                (
+                    f"A Supervisor feature is missing from {fgraph}.\n"
+                    "This is needed for inplace rewrites. Either exclude inplace rewrites or add a Supervisor feature.\n"
+                    "A Supervisor feature can be added via `pytensor.compile.function.types.add_supervisor_to_fgraph`."
+                ),
                 stacklevel=3,
             )
 
diff --git a/pytensor/link/jax/dispatch/scan.py b/pytensor/link/jax/dispatch/scan.py
index d98328f0cf..b638570bd1 100644
--- a/pytensor/link/jax/dispatch/scan.py
+++ b/pytensor/link/jax/dispatch/scan.py
@@ -1,7 +1,7 @@
 import jax
 import jax.numpy as jnp
 
-from pytensor.compile.mode import JAX
+from pytensor.compile.mode import JAX, get_mode
 from pytensor.link.jax.dispatch.basic import jax_funcify
 from pytensor.scan.op import Scan
 
@@ -19,7 +19,9 @@ def jax_funcify_Scan(op: Scan, **kwargs):
     )
 
     # Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode)
-    rewriter = op.mode_instance.excluding(*JAX._optimizer.exclude).optimizer
+    rewriter = (
+        get_mode(op.mode).including("jax").excluding(*JAX._optimizer.exclude).optimizer
+    )
     rewriter(op.fgraph)
 
     scan_inner_func = jax_funcify(op.fgraph, **kwargs)
diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py
index c66a237f06..7685c17d9c 100644
--- a/pytensor/link/numba/dispatch/basic.py
+++ b/pytensor/link/numba/dispatch/basic.py
@@ -16,9 +16,10 @@
 from numba.cpython.unsafe.tuple import tuple_setitem  # noqa: F401
 from numba.extending import box, overload
 
-from pytensor import config
+from pytensor import In, config
 from pytensor.compile import NUMBA
 from pytensor.compile.builders import OpFromGraph
+from pytensor.compile.function.types import add_supervisor_to_fgraph
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
@@ -430,7 +431,13 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
     # TODO: Not sure this is the right place to do this, should we have a rewrite that
     # explicitly triggers the optimization of the inner graphs of OpFromGraph?
     # The C-code defers it to the make_thunk phase
-    NUMBA.optimizer(op.fgraph)
+    fgraph = op.fgraph
+    add_supervisor_to_fgraph(
+        fgraph=fgraph,
+        input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
+        accept_inplace=True,
+    )
+    NUMBA.optimizer(fgraph)
     fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
 
     if len(op.fgraph.outputs) == 1:
diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py
index cc75fc3742..62e4a0608f 100644
--- a/pytensor/link/numba/dispatch/scan.py
+++ b/pytensor/link/numba/dispatch/scan.py
@@ -4,7 +4,9 @@
 from numba import types
 from numba.extending import overload
 
-from pytensor.compile.mode import NUMBA
+from pytensor import In
+from pytensor.compile.function.types import add_supervisor_to_fgraph
+from pytensor.compile.mode import NUMBA, get_mode
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
     create_arg_string,
@@ -59,11 +61,18 @@ def numba_funcify_Scan(op, node, **kwargs):
     # explicitly triggers the optimization of the inner graphs of Scan?
     # The C-code defers it to the make_thunk phase
     rewriter = (
-        op.mode_instance.including("numba")
+        get_mode(op.mode)
+        .including("numba")
         .excluding(*NUMBA._optimizer.exclude)
         .optimizer
     )
-    rewriter(op.fgraph)
+    fgraph = op.fgraph
+    add_supervisor_to_fgraph(
+        fgraph=fgraph,
+        input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
+        accept_inplace=True,
+    )
+    rewriter(fgraph)
 
     scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
 
diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py
index 8e79efda00..94cb51434d 100644
--- a/pytensor/link/numba/dispatch/subtensor.py
+++ b/pytensor/link/numba/dispatch/subtensor.py
@@ -77,10 +77,10 @@ def convert_indices(indices, entry):
         y_name = input_names[1]
 
         if op.set_instead_of_inc:
-            function_name = "setsubtensor"
+            function_name = "set_subtensor"
             index_body = f"z[indices] = {y_name}"
         else:
-            function_name = "incsubtensor"
+            function_name = "inc_subtensor"
             index_body = f"z[indices] += {y_name}"
     else:
         function_name = "subtensor"
diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py
index ef4bf10637..c33b2e6227 100644
--- a/pytensor/link/pytorch/dispatch/basic.py
+++ b/pytensor/link/pytorch/dispatch/basic.py
@@ -5,8 +5,10 @@
 import torch
 import torch.compiler
 
+from pytensor import In
 from pytensor.compile import PYTORCH
 from pytensor.compile.builders import OpFromGraph
+from pytensor.compile.function.types import add_supervisor_to_fgraph
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
@@ -185,6 +187,12 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
     kwargs.pop("storage_map", None)
 
     # Apply inner rewrites
-    PYTORCH.optimizer(op.fgraph)
+    fgraph = op.fgraph
+    add_supervisor_to_fgraph(
+        fgraph=fgraph,
+        input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
+        accept_inplace=True,
+    )
+    PYTORCH.optimizer(fgraph)
     fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
     return fgraph_fn
diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py
index 1dbc93b9fa..7e7e3b2cee 100644
--- a/pytensor/scan/op.py
+++ b/pytensor/scan/op.py
@@ -57,8 +57,9 @@
 from pytensor import tensor as pt
 from pytensor.compile.builders import construct_nominal_fgraph, infer_shape
 from pytensor.compile.function.pfunc import pfunc
+from pytensor.compile.function.types import add_supervisor_to_fgraph
 from pytensor.compile.io import In, Out
-from pytensor.compile.mode import Mode, get_default_mode, get_mode
+from pytensor.compile.mode import Mode, get_mode
 from pytensor.compile.profiling import register_profiler_printer
 from pytensor.configdefaults import config
 from pytensor.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
@@ -761,18 +762,7 @@ def __init__(
         self.profile = profile
         self.allow_gc = allow_gc
         self.strict = strict
-
-        # Clone mode_instance, altering "allow_gc" for the linker,
-        # and adding a message if we profile
-        if self.name:
-            message = f"{self.name} sub profile"
-        else:
-            message = "Scan sub profile"
-
-        self.mode = get_default_mode() if mode is None else mode
-        self.mode_instance = get_mode(self.mode).clone(
-            link_kwargs=dict(allow_gc=self.allow_gc), message=message
-        )
+        self.mode = mode
 
         # build a list of output types for any Apply node using this op.
         self.output_types = []
@@ -845,8 +835,6 @@ def tensorConstructor(shape, dtype):
         self.n_outer_inputs = info.n_outer_inputs
         self.n_outer_outputs = info.n_outer_outputs
 
-        _ = self.prepare_fgraph(self.fgraph)
-
         if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
             raise InconsistencyError(
                 "Inner-graphs must not contain in-place operations."
@@ -1405,23 +1393,8 @@ def prepare_fgraph(self, fgraph):
 
         fgraph.update_mapping = update_mapping
 
-        from pytensor.compile.function.types import Supervisor
-        from pytensor.graph.destroyhandler import DestroyHandler
-
-        for node in fgraph.apply_nodes:
-            if node.op.destroy_map:
-                fgraph.attach_feature(DestroyHandler())
-                break
-
-        fgraph.attach_feature(
-            Supervisor(
-                inp
-                for spec, inp in zip(wrapped_inputs, fgraph.inputs, strict=True)
-                if not (
-                    getattr(spec, "mutable", None)
-                    or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([inp]))
-                )
-            )
+        add_supervisor_to_fgraph(
+            fgraph=fgraph, input_specs=wrapped_inputs, accept_inplace=True
         )
 
         return wrapped_inputs, wrapped_outputs
@@ -1445,10 +1418,17 @@ def fn(self):
         elif self.profile:
             profile = self.profile
 
+        # Clone mode_instance, altering "allow_gc" for the linker,
+        # and adding a message if we profile
+        mode_instance = get_mode(self.mode).clone(
+            link_kwargs=dict(allow_gc=self.allow_gc),
+            message=f"{self.name or 'Scan'} sub profile",
+        )
+
         self._fn = pfunc(
             wrapped_inputs,
             wrapped_outputs,
-            mode=self.mode_instance,
+            mode=mode_instance,
             accept_inplace=False,
             profile=profile,
             on_unused_input="ignore",
diff --git a/pytensor/sparse/rewriting.py b/pytensor/sparse/rewriting.py
index 13735d2aca..4c958dd08f 100644
--- a/pytensor/sparse/rewriting.py
+++ b/pytensor/sparse/rewriting.py
@@ -210,6 +210,7 @@ def local_inplace_addsd_ccode(fgraph, node):
     ),
     "fast_run",
     "inplace",
+    "cxx_only",
     position=50.1,
 )
 
@@ -241,6 +242,7 @@ def local_addsd_ccode(fgraph, node):
     WalkingGraphRewriter(local_addsd_ccode),
     # Must be after local_inplace_addsd_ccode at 70.0
     "fast_run",
+    "cxx_only",
     position=70.1,
 )
 
diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py
index 4857d2f932..654cbe7bd4 100644
--- a/tests/link/numba/test_basic.py
+++ b/tests/link/numba/test_basic.py
@@ -835,6 +835,39 @@ def test_OpFromGraph():
     compare_numba_and_py([x, y, z], [out], [xv, yv, zv])
 
 
+@pytest.mark.filterwarnings("error")
+def test_ofg_inner_inplace():
+    x = pt.vector("x")
+    set0 = x[0].set(1)  # SetSubtensor should not inplace on x
+    exp_x = pt.exp(x)
+    set1 = exp_x[0].set(1)  # SetSubtensor should inplace on exp_x
+    ofg0 = OpFromGraph([x], [set0])
+    ofg1 = OpFromGraph([x], [set1])
+
+    y, z = pt.vectors("y", "z")
+    fn = function([y, z], [ofg0(y), ofg1(z)], mode="NUMBA")
+
+    fn_ofg0 = fn.maker.fgraph.outputs[0].owner.op
+    assert isinstance(fn_ofg0, OpFromGraph)
+    fn_set0 = fn_ofg0.fgraph.outputs[0]
+    assert fn_set0.owner.op.destroy_map == {}
+
+    fn_ofg1 = fn.maker.fgraph.outputs[1].owner.op
+    assert isinstance(fn_ofg1, OpFromGraph)
+    fn_set1 = fn_ofg1.fgraph.outputs[0]
+    assert fn_set1.owner.op.destroy_map == {0: [0]}
+
+    x_test = np.array([0, 1, 1], dtype=config.floatX)
+    y_test = np.array([0, 1, 1], dtype=config.floatX)
+    res0, res1 = fn(x_test, y_test)
+    # Check inputs were not mutated
+    np.testing.assert_allclose(x_test, [0, 1, 1])
+    np.testing.assert_allclose(y_test, [0, 1, 1])
+    # Check outputs are correct
+    np.testing.assert_allclose(res0, [1, 1, 1])
+    np.testing.assert_allclose(res1, [1, np.e, np.e])
+
+
 @pytest.mark.filterwarnings("error")
 def test_cache_warning_suppressed():
     x = pt.vector("x", shape=(5,), dtype="float64")
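
For reference, a minimal usage sketch of the `add_supervisor_to_fgraph` helper introduced above, mirroring the numba and pytorch dispatch changes in this diff; the `optimize_inner_fgraph` wrapper name is hypothetical and not part of the PR:

    from pytensor import In
    from pytensor.compile import NUMBA
    from pytensor.compile.function.types import add_supervisor_to_fgraph

    def optimize_inner_fgraph(op):
        # Hypothetical wrapper: protect every inner-graph input from inplace
        # rewrites by declaring it as a non-mutable In spec, while accepting
        # any inplace operations already present in the inner graph.
        fgraph = op.fgraph
        add_supervisor_to_fgraph(
            fgraph=fgraph,
            input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
            accept_inplace=True,
        )
        # Run the backend rewrites only after the Supervisor feature is attached,
        # so that protected inputs cannot be destroyed in place.
        NUMBA.optimizer(fgraph)
        return fgraph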