Skip to content

Commit 004d765

Browse files
committed
Allow inplace of Elemwise ScalarLoop
1 parent 2774fcb commit 004d765

File tree

4 files changed

+47
-30
lines changed

4 files changed

+47
-30
lines changed

pytensor/scalar/basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4102,6 +4102,7 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
41024102

41034103
def __init__(self, *args, **kwargs):
    """Set up per-instance state, then forward all arguments to the parent initializer.

    Parameters
    ----------
    *args, **kwargs
        Passed through unchanged to ``super().__init__``.
    """
    # Every instance starts with its own empty set.
    self.prepare_node_called = set()
    # Forward positional and keyword arguments unchanged to the next
    # ``__init__`` in the MRO.
    super().__init__(*args, **kwargs)
41054106

41064107
def _cleanup_graph(self, inputs, outputs):
41074108
# TODO: We could convert to TensorVariable, optimize graph,

pytensor/scalar/loop.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(
5555
constant: Sequence[Variable] | None = None,
5656
until: Variable | None = None,
5757
name="ScalarLoop",
58+
**kwargs,
5859
):
5960
if constant is None:
6061
constant = []
@@ -75,7 +76,7 @@ def __init__(
7576
self.nout = len(self.outputs)
7677
self.name = name
7778

78-
super().__init__()
79+
super().__init__(**kwargs)
7980

8081
def output_types(self, input_types):
8182
return self.outputs_type
@@ -115,7 +116,7 @@ def fgraph(self):
115116
self._fgraph = fgraph
116117
return self._fgraph
117118

118-
def clone(self):
119+
def clone(self, name=None, **kwargs):
119120
if self.is_while:
120121
*update, until = self.outputs
121122
else:
@@ -127,28 +128,16 @@ def clone(self):
127128
update=update,
128129
constant=constant,
129130
until=until,
130-
name=self.name,
131+
name=self.name if name is None else name,
132+
**kwargs,
131133
)
132134

133135
@property
134136
def fn(self):
135137
raise NotImplementedError
136138

137139
def make_new_inplace(self, output_types_preference=None, name=None):
    """Return a copy of this op built with the given output-type preference.

    The work is delegated to ``self.clone``; both arguments are forwarded
    to it as keyword arguments, exactly as received.

    Parameters
    ----------
    output_types_preference
        Forwarded to ``clone`` under the same keyword.
    name
        Forwarded to ``clone`` under the same keyword.
    """
    clone_kwargs = {
        "output_types_preference": output_types_preference,
        "name": name,
    }
    return self.clone(**clone_kwargs)
152141

153142
def make_node(self, n_steps, *inputs):
154143
assert len(inputs) == self.nin - 1

pytensor/tensor/rewriting/elemwise.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
)
2525
from pytensor.graph.rewriting.db import SequenceDB
2626
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
27-
from pytensor.scalar.loop import ScalarLoop
2827
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
2928
from pytensor.tensor.basic import (
3029
MakeVector,
@@ -74,15 +73,6 @@ def print_profile(cls, stream, prof, level=0):
7473
for n in sorted(ndim):
7574
print(blanc, n, ndim[n], file=stream)
7675

77-
def candidate_input_idxs(self, node):
78-
# TODO: Implement specialized InplaceCompositeOptimizer with logic
79-
# needed to correctly assign inplace for multi-output Composites
80-
# and ScalarLoops
81-
if isinstance(node.op.scalar_op, ScalarLoop):
82-
return []
83-
else:
84-
return range(len(node.outputs))
85-
8676
def apply(self, fgraph):
8777
r"""
8878
@@ -173,7 +163,7 @@ def apply(self, fgraph):
173163

174164
baseline = op.inplace_pattern
175165
candidate_outputs = [
176-
i for i in self.candidate_input_idxs(node) if i not in baseline
166+
i for i in range(len(node.outputs)) if i not in baseline
177167
]
178168
# node inputs that are Constant, already destroyed,
179169
# or fgraph protected inputs and fgraph outputs can't be used as
@@ -190,7 +180,7 @@ def apply(self, fgraph):
190180
]
191181
else:
192182
baseline = []
193-
candidate_outputs = self.candidate_input_idxs(node)
183+
candidate_outputs = range(len(node.outputs))
194184
# node inputs that are Constant, already destroyed,
195185
# fgraph protected inputs and fgraph outputs can't be used as inplace
196186
# target.

tests/scalar/test_loop.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import numpy as np
44
import pytest
55

6-
from pytensor import Mode, function
6+
from pytensor import In, Mode, function
7+
from pytensor.compile import get_default_mode
78
from pytensor.scalar import (
89
Composite,
910
as_scalar,
@@ -18,6 +19,8 @@
1819
)
1920
from pytensor.scalar.loop import ScalarLoop
2021
from pytensor.tensor import exp as tensor_exp
22+
from pytensor.tensor import vector
23+
from pytensor.tensor.elemwise import Elemwise
2124

2225

2326
mode = pytest.mark.parametrize(
@@ -255,3 +258,37 @@ def test_inner_loop(mode):
255258
out16,
256259
3**2 + 2.5,
257260
)
261+
262+
263+
def test_elemwise_inplace():
    """Check that an Elemwise wrapping a ScalarLoop is compiled to work inplace.

    A two-state loop (``x, y <- x - y, y - x``) is wrapped in an Elemwise and
    compiled with the "inplace" rewrites enabled and every input marked as
    mutable. The test then checks that:

    * the compiled graph still ends in an Elemwise(ScalarLoop),
    * each output destroys a *different* state input, and
    * the returned arrays are the very input buffers that were destroyed,
      holding the expected numerical results.
    """
    x0 = float64("x0")
    y0 = float64("y0")
    x = x0 - y0
    y = y0 - x0
    op = Elemwise(ScalarLoop(init=[x0, y0], constant=[], update=[x, y]))

    n_steps = vector("n_steps", dtype="int64")
    x0v = vector("x0")
    y0v = vector("y0")
    xv, yv = op(n_steps, x0v, y0v)

    fn = function(
        [In(n_steps, mutable=True), In(x0v, mutable=True), In(y0v, mutable=True)],
        [xv, yv],
        mode=get_default_mode().including("inplace"),
    )
    elem_op = fn.maker.fgraph.outputs[0].owner.op
    assert isinstance(elem_op, Elemwise) and isinstance(elem_op.scalar_op, ScalarLoop)
    destroy_map = elem_op.destroy_map
    # Input 0 is n_steps; inputs 1 and 2 are the two state vectors.
    # Either pairing of outputs to state inputs is acceptable, but the two
    # outputs must never be mapped onto the same input buffer.
    assert destroy_map in ({0: [1], 1: [2]}, {0: [2], 1: [1]})

    n_test = np.array([1, 4, 8], dtype="int32")
    x0v_test = np.array([0, 0, 0], dtype=x0v.dtype)
    y0v_test = np.array([1, 1, 1], dtype=y0v.dtype)

    xv_res, yv_res = fn(n_test, x0v_test, y0v_test)
    # Check the outputs are the destroyed inputs
    # (destroy_map indices are offset by one because input 0 is n_steps).
    assert xv_res is (x0v_test, y0v_test)[destroy_map[0][0] - 1]
    assert yv_res is (x0v_test, y0v_test)[destroy_map[1][0] - 1]
    # Starting from (0, 1), each step maps (x, y) to (x - y, y - x), so after
    # n steps the state is (-2**(n-1), 2**(n-1)).
    np.testing.assert_allclose(xv_res, [-1, -8, -128])
    np.testing.assert_allclose(yv_res, [1, 8, 128])

0 commit comments

Comments
 (0)