Skip to content

Commit 128da2a

Browse files
committed
...
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 228fd7e commit 128da2a

File tree

10 files changed

+252
-187
lines changed

10 files changed

+252
-187
lines changed

pytensor/graph/op.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,10 +583,11 @@ def make_thunk(
583583
)
584584
return self.make_py_thunk(node, storage_map, compute_map, no_recycling)
585585

586-
def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
587-
"""Try to return a version of self that can inplace on candidate_inputs."""
586+
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
587+
"""Try to return a version of self that tries to inplace in as many as `allowed_inplace_inputs`."""
588588
# TODO: Document this in the Create your own op docs
589-
raise NotImplementedError()
589+
# By default, do nothing
590+
return self
590591

591592
def __str__(self):
592593
return getattr(type(self), "__name__", super().__str__())

pytensor/link/numba/dispatch/basic.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -660,12 +660,8 @@ def cholesky(a):
660660

661661
else:
662662
# TODO: Use SciPy's BLAS/LAPACK Cython wrappers.
663-
664663
warnings.warn(
665-
(
666-
"Numba will use object mode to allow the "
667-
"`lower` argument to `scipy.linalg.cholesky`."
668-
),
664+
"Numba will use object mode to allow the `lower=False` argument to `scipy.linalg.cholesky`.",
669665
UserWarning,
670666
)
671667

pytensor/tensor/rewriting/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytensor.tensor.rewriting.einsum
77
import pytensor.tensor.rewriting.elemwise
88
import pytensor.tensor.rewriting.extra_ops
9+
import pytensor.tensor.rewriting.inplace
910
import pytensor.tensor.rewriting.jax
1011
import pytensor.tensor.rewriting.linalg
1112
import pytensor.tensor.rewriting.math

pytensor/tensor/rewriting/blas.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -757,8 +757,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
757757
)
758758

759759

760-
# After destroyhandler(49.5) but before we try to make elemwise things
761-
# inplace (75)
760+
# After destroyhandler(49.5) but before we try to make elemwise things inplace (75)
762761
blas_opt_inplace = in2out(
763762
local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace"
764763
)

pytensor/tensor/rewriting/blockwise.py

Lines changed: 1 addition & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
import itertools
2-
from typing import Optional
3-
4-
from pytensor.compile import Supervisor
51
from pytensor.compile.mode import optdb
62
from pytensor.graph import Constant, node_rewriter
73
from pytensor.graph.replace import vectorize_node
8-
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
4+
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
95
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
106
from pytensor.tensor.blockwise import Blockwise
117
from pytensor.tensor.math import Dot
@@ -229,77 +225,3 @@ def local_blockwise_reshape(fgraph, node):
229225
new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
230226
copy_stack_trace(node.outputs[0], new_out)
231227
return [new_out]
232-
233-
234-
@node_rewriter([Blockwise], inplace=True)
235-
def node_blockwise_inplace(fgraph, node):
236-
# Find inputs that are candidates for inplacing
237-
blockwise_op = node.op
238-
239-
if blockwise_op.destroy_map:
240-
# Op already has inplace
241-
return False
242-
243-
core_op = blockwise_op.core_op
244-
batch_ndim = blockwise_op.batch_ndim(node)
245-
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
246-
247-
# TODO: Refactor this code, which is also present in Elemwise Inplacer
248-
protected_inputs = [
249-
f.protected for f in fgraph._features if isinstance(f, Supervisor)
250-
]
251-
protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
252-
protected_inputs.extend(fgraph.outputs)
253-
254-
# TODO: Add test for the broadcastable logic (don't inplace inputs that are being broadcasted)
255-
candidate_inputs = [
256-
idx
257-
for idx, inp in enumerate(node.inputs)
258-
if (
259-
not isinstance(inp, Constant)
260-
and inp.type.broadcastable[:batch_ndim] == out_batch_bcast
261-
and not fgraph.has_destroyers([inp])
262-
and inp not in protected_inputs
263-
)
264-
]
265-
266-
if not candidate_inputs:
267-
return None
268-
269-
try:
270-
inplace_core_op = core_op.try_inplace_inputs(candidate_inputs)
271-
except NotImplementedError:
272-
return False
273-
274-
core_destroy_map = inplace_core_op.destroy_map
275-
276-
if not core_destroy_map:
277-
return False
278-
279-
# Check Op is not trying to inplace on non-candidate inputs
280-
for destroyed_inputs in core_destroy_map.values():
281-
for destroyed_input in destroyed_inputs:
282-
if destroyed_input not in candidate_inputs:
283-
raise ValueError("core_op did not respect candidate inputs")
284-
285-
# Recreate core_op with inplace
286-
inplace_blockwise_op = Blockwise(
287-
core_op=inplace_core_op,
288-
signature=blockwise_op.signature,
289-
name=blockwise_op.name,
290-
gufunc_spec=blockwise_op.gufunc_spec,
291-
destroy_map=core_destroy_map,
292-
)
293-
294-
return inplace_blockwise_op.make_node(*node.inputs).outputs
295-
296-
297-
# After destroyhandler(49.5) but before we try to make elemwise things inplace (75)
298-
blockwise_inplace = in2out(node_blockwise_inplace, name="blockwise_inplace")
299-
optdb.register(
300-
"blockwise_inplace",
301-
blockwise_inplace,
302-
"fast_run",
303-
"inplace",
304-
position=69.0,
305-
)

pytensor/tensor/rewriting/inplace.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import itertools
2+
3+
from pytensor.compile import Supervisor, optdb
4+
from pytensor.graph import Constant
5+
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
6+
from pytensor.tensor.blockwise import Blockwise
7+
from pytensor.tensor.slinalg import Cholesky
8+
9+
10+
def filter_allowed_inplace_inputs(fgraph, node) -> list[int]:
11+
protected_inputs = [
12+
f.protected for f in fgraph._features if isinstance(f, Supervisor)
13+
]
14+
protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
15+
protected_inputs.extend(fgraph.outputs)
16+
17+
return [
18+
idx
19+
for idx, inp in enumerate(node.inputs)
20+
if (
21+
not isinstance(inp, Constant)
22+
and not fgraph.has_destroyers([inp])
23+
and inp not in protected_inputs
24+
)
25+
]
26+
27+
28+
def validate_inplace_inputs(allowed_inplace_inputs, destroy_map):
29+
# Check Op is not trying to inplace on non-candidate inputs
30+
for destroyed_inputs in destroy_map.values():
31+
for destroyed_input in destroyed_inputs:
32+
if destroyed_input not in allowed_inplace_inputs:
33+
raise ValueError(
34+
"Op destroy_map does not respect allowed_inplace_inputs"
35+
)
36+
37+
38+
def make_inplace_core_op(fgraph, node):
39+
# Find inputs that are candidates for inplacing
40+
op = node.op
41+
42+
if op.destroy_map:
43+
# Op already has inplace
44+
return None
45+
46+
allowed_inplace_inputs = filter_allowed_inplace_inputs(fgraph, node)
47+
48+
if not allowed_inplace_inputs:
49+
return None
50+
51+
inplace_op = op.inplace_on_inputs(allowed_inplace_inputs=allowed_inplace_inputs)
52+
53+
if not inplace_op.destroy_map:
54+
return None
55+
56+
validate_inplace_inputs(allowed_inplace_inputs, destroy_map=inplace_op.destroy_map)
57+
58+
out = inplace_op.make_node(*node.inputs).outputs
59+
copy_stack_trace(node.outputs, out)
60+
return inplace_op
61+
62+
63+
@node_rewriter([Cholesky], inplace=True)
64+
def linalg_inplace(fgraph, node):
65+
return make_inplace_core_op(fgraph, node)
66+
67+
68+
@node_rewriter(tracks=[Blockwise])
69+
def blockwise_inplace(fgraph, node):
70+
blockwise_op: Blockwise = node.op
71+
72+
if blockwise_op.destroy_map:
73+
# Op already has inplace
74+
return
75+
76+
batch_ndim = blockwise_op.batch_ndim(node)
77+
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
78+
79+
allowed_inplace_inputs = [
80+
idx
81+
for idx in filter_allowed_inplace_inputs(fgraph, node)
82+
# We can only inplace on inputs that are not being broadcasted
83+
if node.inputs[idx].type.broadcastable[:batch_ndim] == out_batch_bcast
84+
]
85+
86+
inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
87+
allowed_inplace_inputs=allowed_inplace_inputs
88+
)
89+
90+
if not inplace_core_op.destroy_map:
91+
return None
92+
93+
validate_inplace_inputs(
94+
allowed_inplace_inputs, destroy_map=inplace_core_op.destroy_map
95+
)
96+
97+
# Recreate core_op with inplace
98+
inplace_blockwise_op = Blockwise(
99+
core_op=inplace_core_op,
100+
signature=blockwise_op.signature,
101+
name=blockwise_op.name,
102+
gufunc_spec=blockwise_op.gufunc_spec,
103+
destroy_map=inplace_core_op.destroy_map,
104+
)
105+
106+
out = inplace_blockwise_op.make_node(*node.inputs).outputs
107+
copy_stack_trace(node.outputs, out)
108+
return out
109+
110+
111+
# After destroyhandler(49.5) but before we try to make blas (70) and elemwise things inplace (75)
112+
optdb.register(
113+
"linalg_inplace",
114+
in2out(linalg_inplace),
115+
"fast_run",
116+
"inplace",
117+
position=69.0,
118+
)
119+
120+
optdb.register(
121+
"blockwise_inplace",
122+
in2out(blockwise_inplace),
123+
"fast_run",
124+
"inplace",
125+
position=69.0,
126+
)

pytensor/tensor/rewriting/linalg.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
from pytensor import Variable
66
from pytensor.graph import Apply, FunctionGraph
7-
from pytensor.graph.rewriting.basic import (copy_stack_trace, node_rewriter,
7+
from pytensor.graph.rewriting.basic import (
8+
copy_stack_trace,
9+
node_rewriter,
810
)
911
from pytensor.scalar.basic import Mul
1012
from pytensor.tensor.basic import (
@@ -609,26 +611,3 @@ def rewrite_inv_inv(fgraph, node):
609611
):
610612
return None
611613
return [potential_inner_inv.inputs[0]]
612-
613-
614-
cholesky_no_inplace = Cholesky(overwrite_a=False)
615-
cholesky_inplace = Cholesky(overwrite_a=True)
616-
617-
618-
@node_rewriter([cholesky_no_inplace], inplace=True)
619-
@node_rewriter([Cholesky], inplace=True)
620-
def local_inplace_cholesky(fgraph, node):
621-
return make_inplace(node, "overwrite_a")
622-
623-
624-
# After destroyhandler(49.5) but before we try to make elemwise things
625-
# inplace (75)
626-
linalg_opt_inplace = in2out(local_inplace_cholesky, name="linalg_opt_inplace")
627-
optdb.register(
628-
"InplaceLinalgOpt",
629-
linalg_opt_inplace,
630-
"fast_run",
631-
"inplace",
632-
"linalg_opt_inplace",
633-
position=69.0,
634-
)

0 commit comments

Comments (0)