
Commit d49b1b0

Handle inplace rewrites correctly in dispatch of OpFromGraph and Scan
JAX needs no special handling because it excludes inplace rewrites.
1 parent 52548fb commit d49b1b0

7 files changed, +104 -44 lines changed

pytensor/compile/function/types.py

Lines changed: 34 additions & 18 deletions

@@ -168,6 +168,38 @@ def validate(self, fgraph):
                 raise InconsistencyError(f"Trying to destroy a protected variable: {r}")


+def add_supervisor_to_fgraph(
+    fgraph: FunctionGraph, input_specs, accept_inplace: bool = False
+) -> None:
+    """Setup Supervisor Feature in a FunctionGraph, so that inplace rewrites can be used."""
+
+    has_destroy_handler = hasattr(fgraph, "destroyers")
+    if not (has_destroy_handler and accept_inplace):
+        # Check if fgraph already contains destructive operations,
+        # in which case we need to add a DestroyHandler or raise an error
+        for node in fgraph.apply_nodes:
+            if node.op.destroy_map:
+                if not accept_inplace:
+                    raise TypeError(
+                        f"Graph must not contain inplace operations: {node}"
+                    )
+                else:
+                    has_destroy_handler = True
+                    fgraph.attach_feature(DestroyHandler())
+                break
+
+    # Protect all immutable inputs from inplace operations.
+    fgraph.attach_feature(
+        Supervisor(
+            input
+            for spec, input in zip(input_specs, fgraph.inputs, strict=True)
+            if not (
+                spec.mutable or has_destroy_handler and fgraph.has_destroyers([input])
+            )
+        )
+    )
+
+
 def std_fgraph(
     input_specs: list[SymbolicInput],
     output_specs: list[SymbolicOutput],

@@ -229,24 +261,8 @@ def std_fgraph(

     found_updates.extend(map(SymbolicOutput, updates))

-    for node in fgraph.apply_nodes:
-        if node.op.destroy_map:
-            if not accept_inplace:
-                raise TypeError(f"Graph must not contain inplace operations: {node}")
-            else:
-                fgraph.attach_feature(DestroyHandler())
-            break
-
-    # We need to protect all immutable inputs from inplace operations.
-    fgraph.attach_feature(
-        Supervisor(
-            input
-            for spec, input in zip(input_specs, fgraph.inputs, strict=True)
-            if not (
-                spec.mutable
-                or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
-            )
-        )
+    add_supervisor_to_fgraph(
+        fgraph=fgraph, input_specs=input_specs, accept_inplace=accept_inplace
     )

     # If named nodes are replaced, keep the name

pytensor/link/jax/dispatch/scan.py

Lines changed: 4 additions & 2 deletions

@@ -1,7 +1,7 @@
 import jax
 import jax.numpy as jnp

-from pytensor.compile.mode import JAX
+from pytensor.compile.mode import JAX, get_mode
 from pytensor.link.jax.dispatch.basic import jax_funcify
 from pytensor.scan.op import Scan

@@ -19,7 +19,9 @@ def jax_funcify_Scan(op: Scan, **kwargs):
     )

     # Optimize inner graph (exclude any default rewrites that are incompatible with JAX mode)
-    rewriter = op.mode_instance.excluding(*JAX._optimizer.exclude).optimizer
+    rewriter = (
+        get_mode(op.mode).including("jax").excluding(*JAX._optimizer.exclude).optimizer
+    )
     rewriter(op.fgraph)
     scan_inner_func = jax_funcify(op.fgraph, **kwargs)
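As the commit message says, JAX needs no supervisor because JAX mode excludes the inplace rewrites outright; the only change here is deriving the rewriter from the op's stored mode via `get_mode` rather than from a pre-built `mode_instance`. Spelled out as a standalone sketch (using the registered "JAX" mode name for illustration; in the dispatcher the argument is `op.mode`):

from pytensor.compile.mode import JAX, get_mode

rewriter = (
    get_mode("JAX")  # resolves a mode name (or Mode instance) to a Mode
    .including("jax")  # enable the "jax"-tagged rewrites
    .excluding(*JAX._optimizer.exclude)  # drop rewrites JAX cannot honor, e.g. inplace
    .optimizer
)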

pytensor/link/numba/dispatch/basic.py

Lines changed: 9 additions & 2 deletions

@@ -16,9 +16,10 @@
 from numba.cpython.unsafe.tuple import tuple_setitem  # noqa: F401
 from numba.extending import box, overload

-from pytensor import config
+from pytensor import In, config
 from pytensor.compile import NUMBA
 from pytensor.compile.builders import OpFromGraph
+from pytensor.compile.function.types import add_supervisor_to_fgraph
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph

@@ -430,7 +431,13 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
     # TODO: Not sure this is the right place to do this, should we have a rewrite that
     # explicitly triggers the optimization of the inner graphs of OpFromGraph?
     # The C-code defers it to the make_thunk phase
-    NUMBA.optimizer(op.fgraph)
+    fgraph = op.fgraph
+    add_supervisor_to_fgraph(
+        fgraph=fgraph,
+        input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
+        accept_inplace=True,
+    )
+    NUMBA.optimizer(fgraph)
     fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))

     if len(op.fgraph.outputs) == 1:
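The `In(x, borrow=True, mutable=False)` specs tell the Supervisor that rewrites may work inplace on the inner graph's intermediates, but never on the `OpFromGraph` inputs themselves. A condensed sketch of the behavior this guarantees, mirroring the regression test added at the bottom of this commit (the graph is illustrative):

import numpy as np
import pytensor.tensor as pt
from pytensor import config, function
from pytensor.compile.builders import OpFromGraph

x = pt.vector("x")
# SetSubtensor applied directly to the input: must NOT be made inplace.
ofg = OpFromGraph([x], [x[0].set(1)])

y = pt.vector("y")
fn = function([y], ofg(y), mode="NUMBA")

y_test = np.zeros(3, dtype=config.floatX)
res = fn(y_test)
np.testing.assert_allclose(y_test, 0)  # the caller's buffer is untouched
np.testing.assert_allclose(res, [1, 0, 0])  # the update shows up in the output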

pytensor/link/numba/dispatch/scan.py

Lines changed: 12 additions & 3 deletions

@@ -4,7 +4,9 @@
 from numba import types
 from numba.extending import overload

-from pytensor.compile.mode import NUMBA
+from pytensor import In
+from pytensor.compile.function.types import add_supervisor_to_fgraph
+from pytensor.compile.mode import NUMBA, get_mode
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
     create_arg_string,

@@ -59,11 +61,18 @@ def numba_funcify_Scan(op, node, **kwargs):
     # explicitly triggers the optimization of the inner graphs of Scan?
     # The C-code defers it to the make_thunk phase
     rewriter = (
-        op.mode_instance.including("numba")
+        get_mode(op.mode)
+        .including("numba")
         .excluding(*NUMBA._optimizer.exclude)
         .optimizer
     )
-    rewriter(op.fgraph)
+    fgraph = op.fgraph
+    add_supervisor_to_fgraph(
+        fgraph=fgraph,
+        input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
+        accept_inplace=True,
+    )
+    rewriter(fgraph)

     scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
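Scan's inner graph gets the same protection: destructive rewrites may target intermediates of the step function, but not the input slices Scan feeds in. A hypothetical sketch (this scan is invented for illustration, not taken from the commit):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor import config, function

xs = pt.matrix("xs")
# Each step returns its input row with the first entry set to 1; the inner
# SetSubtensor must not be rewritten to act inplace on the input slice.
ys, _ = pytensor.scan(fn=lambda row: row[0].set(1), sequences=[xs])
fn = function([xs], ys, mode="NUMBA")

xs_test = np.zeros((2, 3), dtype=config.floatX)
res = fn(xs_test)
np.testing.assert_allclose(xs_test, 0)  # the sequence buffer is untouched
np.testing.assert_allclose(res[:, 0], 1)  # each step's update is in the output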

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 9 additions & 0 deletions

@@ -5,8 +5,10 @@
 import torch
 import torch.compiler

+from pytensor import In
 from pytensor.compile import PYTORCH
 from pytensor.compile.builders import OpFromGraph
+from pytensor.compile.function.types import add_supervisor_to_fgraph
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph

@@ -185,6 +187,13 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
     kwargs.pop("storage_map", None)
     # Apply inner rewrites
     PYTORCH.optimizer(op.fgraph)
+    fgraph = op.fgraph
+    add_supervisor_to_fgraph(
+        fgraph=fgraph,
+        input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
+        accept_inplace=True,
+    )
+    PYTORCH.optimizer(fgraph)
     fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
     return fgraph_fn

pytensor/scan/op.py

Lines changed: 3 additions & 19 deletions

@@ -57,6 +57,7 @@
 from pytensor import tensor as pt
 from pytensor.compile.builders import construct_nominal_fgraph, infer_shape
 from pytensor.compile.function.pfunc import pfunc
+from pytensor.compile.function.types import add_supervisor_to_fgraph
 from pytensor.compile.io import In, Out
 from pytensor.compile.mode import Mode, get_mode
 from pytensor.compile.profiling import register_profiler_printer

@@ -834,8 +835,6 @@ def tensorConstructor(shape, dtype):
         self.n_outer_inputs = info.n_outer_inputs
         self.n_outer_outputs = info.n_outer_outputs

-        _ = self.prepare_fgraph(self.fgraph)
-
         if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
             raise InconsistencyError(
                 "Inner-graphs must not contain in-place operations."

@@ -1394,23 +1393,8 @@ def prepare_fgraph(self, fgraph):

         fgraph.update_mapping = update_mapping

-        from pytensor.compile.function.types import Supervisor
-        from pytensor.graph.destroyhandler import DestroyHandler
-
-        for node in fgraph.apply_nodes:
-            if node.op.destroy_map:
-                fgraph.attach_feature(DestroyHandler())
-                break
-
-        fgraph.attach_feature(
-            Supervisor(
-                inp
-                for spec, inp in zip(wrapped_inputs, fgraph.inputs, strict=True)
-                if not (
-                    getattr(spec, "mutable", None)
-                    or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([inp]))
-                )
-            )
+        add_supervisor_to_fgraph(
+            fgraph=fgraph, input_specs=wrapped_inputs, accept_inplace=True
         )

         return wrapped_inputs, wrapped_outputs

tests/link/numba/test_basic.py

Lines changed: 33 additions & 0 deletions

@@ -835,6 +835,39 @@ def test_OpFromGraph():
     compare_numba_and_py([x, y, z], [out], [xv, yv, zv])


+@pytest.mark.filterwarnings("error")
+def test_ofg_inner_inplace():
+    x = pt.vector("x")
+    set0 = x[0].set(1)  # SetSubtensor should not inplace on x
+    exp_x = pt.exp(x)
+    set1 = exp_x[0].set(1)  # SetSubtensor should inplace on exp_x
+    ofg0 = OpFromGraph([x], [set0])
+    ofg1 = OpFromGraph([x], [set1])
+
+    y, z = pt.vectors("y", "z")
+    fn = function([y, z], [ofg0(y), ofg1(z)], mode="NUMBA")
+
+    fn_ofg0 = fn.maker.fgraph.outputs[0].owner.op
+    assert isinstance(fn_ofg0, OpFromGraph)
+    fn_set0 = fn_ofg0.fgraph.outputs[0]
+    assert fn_set0.owner.op.destroy_map == {}
+
+    fn_ofg1 = fn.maker.fgraph.outputs[1].owner.op
+    assert isinstance(fn_ofg1, OpFromGraph)
+    fn_set1 = fn_ofg1.fgraph.outputs[0]
+    assert fn_set1.owner.op.destroy_map == {0: [0]}
+
+    x_test = np.array([0, 1, 1], dtype=config.floatX)
+    y_test = np.array([0, 1, 1], dtype=config.floatX)
+    res0, res1 = fn(x_test, y_test)
+    # Check inputs were not mutated
+    np.testing.assert_allclose(x_test, [0, 1, 1])
+    np.testing.assert_allclose(y_test, [0, 1, 1])
+    # Check outputs are correct
+    np.testing.assert_allclose(res0, [1, 1, 1])
+    np.testing.assert_allclose(res1, [1, np.e, np.e])
+
+
 @pytest.mark.filterwarnings("error")
 def test_cache_warning_suppressed():
     x = pt.vector("x", shape=(5,), dtype="float64")
