Skip to content

Commit 652d0b6

Browse files
committed
Fix bug in local_dimshuffle_lift when elemwise has multiple outputs
1 parent 712660e commit 652d0b6

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,12 @@ def local_dimshuffle_lift(fgraph, node):
422422
inp = node.inputs[0]
423423
inode = inp.owner
424424
new_order = op.new_order
425-
if inode and isinstance(inode.op, Elemwise) and (len(fgraph.clients[inp]) == 1):
425+
if (
426+
inode
427+
and isinstance(inode.op, Elemwise)
428+
and len(inode.outputs) == 1
429+
and (len(fgraph.clients[inp]) == 1)
430+
):
426431
# Don't use make_node to have tag.test_value set.
427432
new_inputs = []
428433
for inp in inode.inputs:

tests/tensor/rewriting/test_elemwise.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pytensor.graph.rewriting.utils import rewrite_graph
2020
from pytensor.misc.safe_asarray import _asarray
2121
from pytensor.raise_op import assert_op
22-
from pytensor.scalar.basic import Composite
22+
from pytensor.scalar.basic import Composite, float64
2323
from pytensor.tensor.basic import MakeVector
2424
from pytensor.tensor.elemwise import DimShuffle, Elemwise
2525
from pytensor.tensor.math import abs as at_abs
@@ -163,6 +163,20 @@ def test_dimshuffle_on_broadcastable(self):
163163
# Check stacktrace was copied over correctly after rewrite was applied
164164
assert hasattr(g.outputs[0].tag, "trace")
165165

166+
def test_dimshuffle_lift_multi_out_elemwise(self):
167+
# Create a multi-output Elemwise Op with Composite
168+
x = float64("x")
169+
outs = [x + 1, x + 2]
170+
op = Elemwise(Composite([x], outs))
171+
172+
# Transpose both outputs
173+
x = matrix("x")
174+
outs = [out.T for out in op(x)]
175+
176+
# Make sure rewrite doesn't apply in this case
177+
g = FunctionGraph([x], outs)
178+
assert not local_dimshuffle_lift.transform(g, g.outputs[0].owner)
179+
166180

167181
def test_local_useless_dimshuffle_in_reshape():
168182
vec = TensorType(dtype="float64", shape=(None,))("vector")

0 commit comments

Comments
 (0)