Skip to content

Commit 6518da2

Browse files
committed
fixed Einsum failing with repeated inputs
1 parent 2a7f3e1 commit 6518da2

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

pytensor/tensor/einsum.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,17 @@ def _right_to_left_path(n: int) -> tuple[tuple[int, int], ...]:
417417
return tuple(pairwise(reversed(range(n))))
418418

419419

420+
def _ensure_not_equal(elements):
421+
"""
422+
Ensures that any pair in a list of elements are not the same object. If a pair of elements is found to be equal, then one of them is converted to a copy.
423+
"""
424+
for i in range(len(elements)):
425+
for j in range(i + 1, len(elements)):
426+
if elements[i] == elements[j]:
427+
elements[j] = elements[j].copy()
428+
return elements
429+
430+
420431
def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVariable:
421432
"""
422433
Multiplication and summation of tensors using the Einstein summation convention.
@@ -553,7 +564,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
553564
"If you need this functionality open an issue in https://github.com/pymc-devs/pytensor/issues to let us know. "
554565
)
555566

556-
tensor_operands = [as_tensor(operand) for operand in operands]
567+
tensor_operands = _ensure_not_equal([as_tensor(operand) for operand in operands])
557568
shapes = [operand.type.shape for operand in tensor_operands]
558569

559570
path: PATH

0 commit comments

Comments
 (0)