74 changes: 56 additions & 18 deletions pytensor/compile/function/types.py
@@ -5,6 +5,7 @@
import logging
import time
import warnings
from collections.abc import Sequence
from itertools import chain
from typing import TYPE_CHECKING

@@ -168,6 +169,59 @@
raise InconsistencyError(f"Trying to destroy a protected variable: {r}")


def add_supervisor_to_fgraph(
fgraph: FunctionGraph,
input_specs: Sequence[SymbolicInput],
accept_inplace: bool = False,
) -> None:
"""Setup Supervisor Feature in a FunctionGraph, so that inplace rewrites can be used.

Parameters
----------
fgraph: FunctionGraph
The FunctionGraph in which to set up the Supervisor feature.
input_specs: Sequence of SymbolicInput
The input specifications for the FunctionGraph.
Inputs with `mutable=False` that are not already destroyed by an inplace operation
(when `accept_inplace` is True) are protected from inplace operations;
all other inputs may be destroyed.
accept_inplace: bool
Whether to allow inplace operations to already be present in the graph.

Raises
------
TypeError
If inplace operations are not allowed and the graph already contains inplace operations.

"""

has_destroy_handler = hasattr(fgraph, "destroyers")
if not (has_destroy_handler and accept_inplace):
# Check if fgraph already contains destructive operations,
# in which case we need to add a DestroyHandler or raise an error
for node in fgraph.apply_nodes:
if node.op.destroy_map:
if not accept_inplace:
raise TypeError(
f"Graph must not contain inplace operations: {node}"
)
else:
has_destroy_handler = True
fgraph.attach_feature(DestroyHandler())
break

# Protect all immutable inputs from inplace operations.
fgraph.attach_feature(
Supervisor(
input
for spec, input in zip(input_specs, fgraph.inputs, strict=True)
if not (
spec.mutable or has_destroy_handler and fgraph.has_destroyers([input])
)
)
)
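
A minimal usage sketch, mirroring how the Numba and PyTorch dispatchers below call this helper (here `inner_fgraph` is an illustrative placeholder for an already-built inner FunctionGraph):

from pytensor import In
from pytensor.compile import NUMBA
from pytensor.compile.function.types import add_supervisor_to_fgraph

# Protect every inner input from being destroyed, while accepting any
# inplace operations already present in the graph.
add_supervisor_to_fgraph(
    fgraph=inner_fgraph,
    input_specs=[In(x, borrow=True, mutable=False) for x in inner_fgraph.inputs],
    accept_inplace=True,
)
# With the Supervisor attached, the backend's inplace rewrites can run
# without the "Supervisor feature is missing" warning.
NUMBA.optimizer(inner_fgraph)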


def std_fgraph(
input_specs: list[SymbolicInput],
output_specs: list[SymbolicOutput],
@@ -229,24 +283,8 @@

found_updates.extend(map(SymbolicOutput, updates))

for node in fgraph.apply_nodes:
if node.op.destroy_map:
if not accept_inplace:
raise TypeError(f"Graph must not contain inplace operations: {node}")
else:
fgraph.attach_feature(DestroyHandler())
break

# We need to protect all immutable inputs from inplace operations.
fgraph.attach_feature(
Supervisor(
input
for spec, input in zip(input_specs, fgraph.inputs, strict=True)
if not (
spec.mutable
or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
)
)
add_supervisor_to_fgraph(
fgraph=fgraph, input_specs=input_specs, accept_inplace=accept_inplace
)

# If named nodes are replaced, keep the name
6 changes: 5 additions & 1 deletion pytensor/compile/mode.py
@@ -138,7 +138,11 @@ def apply(self, fgraph):
break
if not supervisor_added:
warnings.warn(
f"A Supervisor feature is missing from {fgraph}.",
(
f"A Supervisor feature is missing from {fgraph}.\n"
"This is needed for inplace rewrites. Either exclude inplace rewrites or add a Supervisor feature.\n"
"A Supervisor feature can be added via `pytensor.compile.function.types.add_supervisor_to_fgraph`."
),
stacklevel=3,
)

6 changes: 4 additions & 2 deletions pytensor/link/jax/dispatch/scan.py
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp

from pytensor.compile.mode import JAX
from pytensor.compile.mode import JAX, get_mode
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.scan.op import Scan

@@ -19,7 +19,9 @@ def jax_funcify_Scan(op: Scan, **kwargs):
)

# Optimize inner graph (exclude any default rewrites that are incompatible with JAX mode)
rewriter = op.mode_instance.excluding(*JAX._optimizer.exclude).optimizer
rewriter = (
get_mode(op.mode).including("jax").excluding(*JAX._optimizer.exclude).optimizer
)
rewriter(op.fgraph)
scan_inner_func = jax_funcify(op.fgraph, **kwargs)

11 changes: 9 additions & 2 deletions pytensor/link/numba/dispatch/basic.py
@@ -16,9 +16,10 @@
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import box, overload

from pytensor import config
from pytensor import In, config
from pytensor.compile import NUMBA
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
@@ -430,7 +431,13 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of OpFromGraph?
# The C-code defers it to the make_thunk phase
NUMBA.optimizer(op.fgraph)
fgraph = op.fgraph
add_supervisor_to_fgraph(
fgraph=fgraph,
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
accept_inplace=True,
)
NUMBA.optimizer(fgraph)
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))

if len(op.fgraph.outputs) == 1:
15 changes: 12 additions & 3 deletions pytensor/link/numba/dispatch/scan.py
@@ -4,7 +4,9 @@
from numba import types
from numba.extending import overload

from pytensor.compile.mode import NUMBA
from pytensor import In
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.mode import NUMBA, get_mode
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_arg_string,
@@ -59,11 +61,18 @@ def numba_funcify_Scan(op, node, **kwargs):
# explicitly triggers the optimization of the inner graphs of Scan?
# The C-code defers it to the make_thunk phase
rewriter = (
op.mode_instance.including("numba")
get_mode(op.mode)
.including("numba")
.excluding(*NUMBA._optimizer.exclude)
.optimizer
)
rewriter(op.fgraph)
fgraph = op.fgraph
add_supervisor_to_fgraph(
fgraph=fgraph,
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
accept_inplace=True,
)
rewriter(fgraph)

Review comment (Member): If this scan is an inner function to something else, do we still want mutable=False?

Reply (Member Author): For now yes, this should be handled by inplace rewrites so that we know which inputs are safe to destroy by the time we get here. But I will still have to look at what those do exactly, and we weren't handling it before in the dispatch. Scan has a very specific view of inplacing, built around the constraint that it was compiling a full Pytensor function internally. I'll open an issue to investigate.

scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))

4 changes: 2 additions & 2 deletions pytensor/link/numba/dispatch/subtensor.py
@@ -77,10 +77,10 @@ def convert_indices(indices, entry):
y_name = input_names[1]

if op.set_instead_of_inc:
function_name = "setsubtensor"
function_name = "set_subtensor"
index_body = f"z[indices] = {y_name}"
else:
function_name = "incsubtensor"
function_name = "inc_subtensor"
index_body = f"z[indices] += {y_name}"
else:
function_name = "subtensor"
9 changes: 9 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -5,8 +5,10 @@
import torch
import torch.compiler

from pytensor import In
from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
@@ -185,6 +187,13 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
kwargs.pop("storage_map", None)
# Apply inner rewrites
PYTORCH.optimizer(op.fgraph)
fgraph = op.fgraph
add_supervisor_to_fgraph(
fgraph=fgraph,
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
accept_inplace=True,
)
PYTORCH.optimizer(fgraph)
fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
return fgraph_fn

46 changes: 13 additions & 33 deletions pytensor/scan/op.py
@@ -57,8 +57,9 @@
from pytensor import tensor as pt
from pytensor.compile.builders import construct_nominal_fgraph, infer_shape
from pytensor.compile.function.pfunc import pfunc
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.io import In, Out
from pytensor.compile.mode import Mode, get_default_mode, get_mode
from pytensor.compile.mode import Mode, get_mode
from pytensor.compile.profiling import register_profiler_printer
from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
@@ -761,18 +762,7 @@ def __init__(
self.profile = profile
self.allow_gc = allow_gc
self.strict = strict

# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if self.name:
message = f"{self.name} sub profile"
else:
message = "Scan sub profile"

self.mode = get_default_mode() if mode is None else mode
self.mode_instance = get_mode(self.mode).clone(
link_kwargs=dict(allow_gc=self.allow_gc), message=message
)
self.mode = mode

# build a list of output types for any Apply node using this op.
self.output_types = []
@@ -845,8 +835,6 @@ def tensorConstructor(shape, dtype):
self.n_outer_inputs = info.n_outer_inputs
self.n_outer_outputs = info.n_outer_outputs

_ = self.prepare_fgraph(self.fgraph)

if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
raise InconsistencyError(
"Inner-graphs must not contain in-place operations."
@@ -1405,23 +1393,8 @@ def prepare_fgraph(self, fgraph):

fgraph.update_mapping = update_mapping

from pytensor.compile.function.types import Supervisor
from pytensor.graph.destroyhandler import DestroyHandler

for node in fgraph.apply_nodes:
if node.op.destroy_map:
fgraph.attach_feature(DestroyHandler())
break

fgraph.attach_feature(
Supervisor(
inp
for spec, inp in zip(wrapped_inputs, fgraph.inputs, strict=True)
if not (
getattr(spec, "mutable", None)
or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([inp]))
)
)
add_supervisor_to_fgraph(
fgraph=fgraph, input_specs=wrapped_inputs, accept_inplace=True
)

return wrapped_inputs, wrapped_outputs
@@ -1445,10 +1418,17 @@ def fn(self):
elif self.profile:
profile = self.profile

# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
mode_instance = get_mode(self.mode).clone(
link_kwargs=dict(allow_gc=self.allow_gc),
message=f"{self.name or 'Scan'} sub profile",
)

self._fn = pfunc(
wrapped_inputs,
wrapped_outputs,
mode=self.mode_instance,
mode=mode_instance,
accept_inplace=False,
profile=profile,
on_unused_input="ignore",
2 changes: 2 additions & 0 deletions pytensor/sparse/rewriting.py
@@ -210,6 +210,7 @@ def local_inplace_addsd_ccode(fgraph, node):
),
"fast_run",
"inplace",
"cxx_only",
position=50.1,
)

@@ -241,6 +242,7 @@ def local_addsd_ccode(fgraph, node):
WalkingGraphRewriter(local_addsd_ccode),
# Must be after local_inplace_addsd_ccode at 70.0
"fast_run",
"cxx_only",
position=70.1,
)

33 changes: 33 additions & 0 deletions tests/link/numba/test_basic.py
@@ -835,6 +835,39 @@ def test_OpFromGraph():
compare_numba_and_py([x, y, z], [out], [xv, yv, zv])


@pytest.mark.filterwarnings("error")
def test_ofg_inner_inplace():
x = pt.vector("x")
set0 = x[0].set(1) # SetSubtensor should not inplace on x
exp_x = pt.exp(x)
set1 = exp_x[0].set(1) # SetSubtensor should inplace on exp_x
ofg0 = OpFromGraph([x], [set0])
ofg1 = OpFromGraph([x], [set1])

y, z = pt.vectors("y", "z")
fn = function([y, z], [ofg0(y), ofg1(z)], mode="NUMBA")

fn_ofg0 = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(fn_ofg0, OpFromGraph)
fn_set0 = fn_ofg0.fgraph.outputs[0]
assert fn_set0.owner.op.destroy_map == {}

fn_ofg1 = fn.maker.fgraph.outputs[1].owner.op
assert isinstance(fn_ofg1, OpFromGraph)
fn_set1 = fn_ofg1.fgraph.outputs[0]
assert fn_set1.owner.op.destroy_map == {0: [0]}

x_test = np.array([0, 1, 1], dtype=config.floatX)
y_test = np.array([0, 1, 1], dtype=config.floatX)
res0, res1 = fn(x_test, y_test)
# Check inputs were not mutated
np.testing.assert_allclose(x_test, [0, 1, 1])
np.testing.assert_allclose(y_test, [0, 1, 1])
# Check outputs are correct
np.testing.assert_allclose(res0, [1, 1, 1])
np.testing.assert_allclose(res1, [1, np.e, np.e])


@pytest.mark.filterwarnings("error")
def test_cache_warning_suppressed():
x = pt.vector("x", shape=(5,), dtype="float64")