
Commit 228fd7e

jessegrabowski authored and ricardoV94 committed
Resolve conflicts with main

Generalize Blockwise inplace logic:
- Introduce `make_inplace` helper function for destructive rewrites
- Refactor the Cholesky destructive rewrite to use the `make_inplace` helper
- Add destructive in-place rewrite for `pt.linalg.cholesky`
1 parent 7eca252 · commit 228fd7e

File tree

9 files changed: +374 −38 lines changed
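Taken together, these changes let `fast_run` graphs factorize eligible buffers in place. A minimal sketch of the intended user-visible behavior on this branch (the rewrite only fires on intermediate values that nothing else needs, so the input itself is never clobbered):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
# `x + 1` is an intermediate buffer, so the destructive rewrite may replace
# Cholesky(overwrite_a=False) with Cholesky(overwrite_a=True) under fast_run.
L = pt.linalg.cholesky(x + 1)

f = pytensor.function([x], L)
a = np.array([[3.0, 1.0], [1.0, 2.0]])
np.testing.assert_allclose(f(a) @ f(a).T, a + 1)  # L is lower-triangular
```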

pytensor/graph/op.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -583,6 +583,11 @@ def make_thunk(
         )
         return self.make_py_thunk(node, storage_map, compute_map, no_recycling)

+    def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
+        """Try to return a version of self that can inplace on candidate_inputs."""
+        # TODO: Document this in the Create your own op docs
+        raise NotImplementedError()
+
     def __str__(self):
         return getattr(type(self), "__name__", super().__str__())
```
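The new hook gives rewrites a generic way to ask a core Op for a destructive variant of itself. A hypothetical sketch of an Op honoring it (`MyCholesky` and its `overwrite_a` flag are invented for illustration; only the hook's signature comes from this diff):

```python
from pytensor.graph.op import Op


class MyCholesky(Op):  # hypothetical Op, for illustration only
    def __init__(self, overwrite_a=False):
        self.overwrite_a = overwrite_a
        if overwrite_a:
            # Output 0 reuses (destroys) the buffer of input 0
            self.destroy_map = {0: [0]}

    def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
        # Only input 0 can be destroyed; fall back to the default otherwise
        if 0 in candidate_inputs and not self.overwrite_a:
            return type(self)(overwrite_a=True)
        raise NotImplementedError()
```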

pytensor/link/numba/dispatch/basic.py

Lines changed: 35 additions & 1 deletion

```diff
@@ -36,7 +36,7 @@
 from pytensor.tensor.blas import BatchedDot
 from pytensor.tensor.math import Dot
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
-from pytensor.tensor.slinalg import Solve
+from pytensor.tensor.slinalg import Cholesky, Solve
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type_other import MakeSlice, NoneConst
@@ -646,6 +646,40 @@ def softplus(x):
     return softplus


+@numba_funcify.register(Cholesky)
+def numba_funcify_Cholesky(op, node, **kwargs):
+    lower = op.lower
+    out_dtype = node.outputs[0].type.numpy_dtype
+
+    if lower:
+        inputs_cast = int_to_float_fn(node.inputs, out_dtype)
+
+        @numba_njit
+        def cholesky(a):
+            return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)
+
+    else:
+        # TODO: Use SciPy's BLAS/LAPACK Cython wrappers.
+
+        warnings.warn(
+            (
+                "Numba will use object mode to allow the "
+                "`lower` argument to `scipy.linalg.cholesky`."
+            ),
+            UserWarning,
+        )
+
+        ret_sig = get_numba_type(node.outputs[0].type)
+
+        @numba_njit
+        def cholesky(a):
+            with numba.objmode(ret=ret_sig):
+                ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype)
+            return ret
+
+    return cholesky
+
+
 @numba_funcify.register(Solve)
 def numba_funcify_Solve(op, node, **kwargs):
     assume_a = op.assume_a
```
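A quick smoke test of the new dispatch might look like the sketch below (assuming a Numba-enabled install; with `lower=False` the object-mode warning above should be emitted instead):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
# lower=True (the default) takes the fully jitted np.linalg.cholesky path
f = pytensor.function([x], pt.linalg.cholesky(x), mode="NUMBA")

a = np.array([[4.0, 2.0], [2.0, 3.0]])
L = f(a)
np.testing.assert_allclose(L @ L.T, a)
```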

pytensor/tensor/blockwise.py

Lines changed: 11 additions & 0 deletions

```diff
@@ -45,6 +45,7 @@ def __init__(
         signature: str | None = None,
         name: str | None = None,
         gufunc_spec: tuple[str, int, int] | None = None,
+        destroy_map=None,
         **kwargs,
     ):
         """
@@ -79,6 +80,16 @@ def __init__(
         self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
         self.gufunc_spec = gufunc_spec
         self._gufunc = None
+        if destroy_map is not None:
+            # TODO: Check core_op destroy_map is compatible with Blockwise destroy_map
+            self.destroy_map = destroy_map
+        if self.destroy_map != core_op.destroy_map:
+            # Note: Should be fine for destroy_map of Blockwise to be more extensive than that of core_op
+            # But we are not using that anywhere yet, so this check is fine for now
+            raise ValueError(
+                "Blockwise destroy_map must be the same as that of the core_op"
+            )
+
         super().__init__(**kwargs)

     def __getstate__(self):
```
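The new `destroy_map` argument is validated against the core Op's own `destroy_map`. A sketch of the contract, using the `overwrite_a` flag this commit adds to `Cholesky` (the exact signature string is an assumption based on Cholesky's gufunc shape):

```python
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import Cholesky

core_op = Cholesky(overwrite_a=True)  # core Op destroys input 0

# OK: the Blockwise destroy_map matches the core Op's
good = Blockwise(core_op, signature="(m,m)->(m,m)", destroy_map={0: [0]})

# Raises ValueError: the two maps disagree
bad = Blockwise(core_op, signature="(m,m)->(m,m)", destroy_map={0: [0, 1]})
```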

pytensor/tensor/rewriting/blockwise.py

Lines changed: 80 additions & 2 deletions

```diff
@@ -1,7 +1,11 @@
+import itertools
+from typing import Optional
+
+from pytensor.compile import Supervisor
 from pytensor.compile.mode import optdb
 from pytensor.graph import Constant, node_rewriter
 from pytensor.graph.replace import vectorize_node
-from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
+from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
 from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.math import Dot
@@ -56,7 +60,7 @@ def local_useless_unbatched_blockwise(fgraph, node):
     "fast_run",
     "fast_compile",
     "blockwise",
-    position=49,
+    position=99,  # TODO: Check if this makes sense
 )
@@ -225,3 +229,77 @@ def local_blockwise_reshape(fgraph, node):
     new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
     copy_stack_trace(node.outputs[0], new_out)
     return [new_out]
+
+
+@node_rewriter([Blockwise], inplace=True)
+def node_blockwise_inplace(fgraph, node):
+    # Find inputs that are candidates for inplacing
+    blockwise_op = node.op
+
+    if blockwise_op.destroy_map:
+        # Op already has inplace
+        return False
+
+    core_op = blockwise_op.core_op
+    batch_ndim = blockwise_op.batch_ndim(node)
+    out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
+
+    # TODO: Refactor this code, which is also present in Elemwise Inplacer
+    protected_inputs = [
+        f.protected for f in fgraph._features if isinstance(f, Supervisor)
+    ]
+    protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
+    protected_inputs.extend(fgraph.outputs)
+
+    # TODO: Add test for the broadcastable logic (don't inplace inputs that are being broadcasted)
+    candidate_inputs = [
+        idx
+        for idx, inp in enumerate(node.inputs)
+        if (
+            not isinstance(inp, Constant)
+            and inp.type.broadcastable[:batch_ndim] == out_batch_bcast
+            and not fgraph.has_destroyers([inp])
+            and inp not in protected_inputs
+        )
+    ]
+
+    if not candidate_inputs:
+        return None
+
+    try:
+        inplace_core_op = core_op.try_inplace_inputs(candidate_inputs)
+    except NotImplementedError:
+        return False
+
+    core_destroy_map = inplace_core_op.destroy_map
+
+    if not core_destroy_map:
+        return False
+
+    # Check Op is not trying to inplace on non-candidate inputs
+    for destroyed_inputs in core_destroy_map.values():
+        for destroyed_input in destroyed_inputs:
+            if destroyed_input not in candidate_inputs:
+                raise ValueError("core_op did not respect candidate inputs")
+
+    # Recreate core_op with inplace
+    inplace_blockwise_op = Blockwise(
+        core_op=inplace_core_op,
+        signature=blockwise_op.signature,
+        name=blockwise_op.name,
+        gufunc_spec=blockwise_op.gufunc_spec,
+        destroy_map=core_destroy_map,
+    )
+
+    return inplace_blockwise_op.make_node(*node.inputs).outputs
+
+
+# After destroyhandler(49.5) but before we try to make elemwise things inplace (75)
+blockwise_inplace = in2out(node_blockwise_inplace, name="blockwise_inplace")
+optdb.register(
+    "blockwise_inplace",
+    blockwise_inplace,
+    "fast_run",
+    "inplace",
+    position=69.0,
+)
```
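The pass walks Blockwise nodes after the destroy handler and swaps in a destructive core Op when a safe input exists (non-constant, batch dims matching the output, not protected, not already destroyed). A sketch of the intended effect on a batched graph (the printed destroy_map is illustrative):

```python
import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x")            # batched input -> Blockwise(Cholesky)
L = pt.linalg.cholesky(x + 1)  # `x + 1` is an intermediate, safe to destroy

f = pytensor.function([x], L)
# After the blockwise_inplace pass, the Blockwise node should carry its
# core Op's destroy_map, e.g. {0: [0]}, so the factorization reuses x + 1.
pytensor.dprint(f)
```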

pytensor/tensor/rewriting/elemwise.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -186,9 +186,8 @@ def apply(self, fgraph):
                     for i in range(len(node.inputs))
                     if i not in baseline.values()
                     and not isinstance(node.inputs[i], Constant)
-                    and
                     # the next line should not be costly most of the time.
-                    not fgraph.has_destroyers([node.inputs[i]])
+                    and not fgraph.has_destroyers([node.inputs[i]])
                     and node.inputs[i] not in protected_inputs
                 ]
             else:
```

pytensor/tensor/rewriting/linalg.py

Lines changed: 24 additions & 3 deletions

```diff
@@ -4,9 +4,7 @@

 from pytensor import Variable
 from pytensor.graph import Apply, FunctionGraph
-from pytensor.graph.rewriting.basic import (
-    copy_stack_trace,
-    node_rewriter,
+from pytensor.graph.rewriting.basic import (copy_stack_trace, node_rewriter,
 )
 from pytensor.scalar.basic import Mul
 from pytensor.tensor.basic import (
@@ -611,3 +609,26 @@ def rewrite_inv_inv(fgraph, node):
     ):
         return None
     return [potential_inner_inv.inputs[0]]
+
+
+cholesky_no_inplace = Cholesky(overwrite_a=False)
+cholesky_inplace = Cholesky(overwrite_a=True)
+
+
+@node_rewriter([cholesky_no_inplace], inplace=True)
+def local_inplace_cholesky(fgraph, node):
+    return make_inplace(node, "overwrite_a")
+
+
+# After destroyhandler(49.5) but before we try to make elemwise things
+# inplace (75)
+linalg_opt_inplace = in2out(local_inplace_cholesky, name="linalg_opt_inplace")
+optdb.register(
+    "InplaceLinalgOpt",
+    linalg_opt_inplace,
+    "fast_run",
+    "inplace",
+    "linalg_opt_inplace",
+    position=69.0,
+)
```
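One way to confirm the destructive rewrite fired is to scan the compiled graph for a Cholesky node with a non-empty destroy_map; a sketch, assuming the default `fast_run` mode:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import Cholesky

x = pt.matrix("x")
f = pytensor.function([x], pt.linalg.cholesky(x + 1))

# Expect Cholesky(overwrite_a=True), whose output 0 destroys input 0
assert any(
    isinstance(node.op, Cholesky) and node.op.destroy_map == {0: [0]}
    for node in f.maker.fgraph.apply_nodes
)
```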

0 commit comments