Skip to content

Commit 0822179

Browse files
committed
Fix bug in einsum
A shortcut in the numpy implementation of einsum_path when there's nothing to optimize, creates a default path that can combine more than 2 operands. Our implementation only works with 2 or 1 operand operations at each step. https://github.com/numpy/numpy/blob/cc5851e654bfd82a23f2758be4bd224be84fc1c3/numpy/_core/einsumfunc.py#L945-L951
1 parent f08c191 commit 0822179

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

pytensor/tensor/einsum.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,12 @@ def _contraction_list_from_path(
410410
return contraction_list
411411

412412

413+
def _right_to_left_path(n: int) -> tuple[tuple[int, int], ...]:
414+
# Create a right to left contraction path
415+
# if n = 5, out = ((4, 3), (3, 2), (2, 1), (1, 0))
416+
return tuple(pairwise(reversed(range(n))))
417+
418+
413419
def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVariable:
414420
"""
415421
Multiplication and summation of tensors using the Einstein summation convention.
@@ -546,8 +552,6 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
546552
"If you need this functionality open an issue in https://github.com/pymc-devs/pytensor/issues to let us know. "
547553
)
548554

549-
# TODO: Is this doing something clever about unknown shapes?
550-
# contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
551555
tensor_operands = [as_tensor(operand) for operand in operands]
552556
shapes = [operand.type.shape for operand in tensor_operands]
553557

@@ -565,7 +569,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
565569
else:
566570
# By default, we try right to left because we assume that most graphs
567571
# have a lower dimensional rightmost operand
568-
path = tuple(pairwise(reversed(range(len(tensor_operands)))))
572+
path = _right_to_left_path(len(tensor_operands))
569573
contraction_list = _contraction_list_from_path(
570574
subscripts, tensor_operands, path
571575
)
@@ -584,6 +588,15 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
584588
optimize="optimal",
585589
) # type: ignore
586590
path = tuple(contraction[0] for contraction in contraction_list)
591+
592+
if len(path) == 1 and len(path[0]) > 2:
593+
# When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing
594+
# pairwise reductions, which our implementation below demands.
595+
path = _right_to_left_path(len(tensor_operands))
596+
contraction_list = _contraction_list_from_path(
597+
subscripts, tensor_operands, path
598+
)
599+
587600
optimized = True
588601

589602
def removechars(s, chars):
@@ -745,6 +758,7 @@ def filter_singleton_dims(operand, names, other_operand, other_names):
745758
batch_axes=(lhs_batch, rhs_batch),
746759
)
747760
else:
761+
print(operand_indices, contracted_names, einstr)
748762
raise ValueError(
749763
f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}"
750764
)

tests/tensor/test_einsum.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,22 @@ def test_broadcastable_dims():
262262
atol = 1e-12 if config.floatX == "float64" else 1e-2
263263
np.testing.assert_allclose(suboptimal_eval, np_eval, atol=atol)
264264
np.testing.assert_allclose(optimal_eval, np_eval, atol=atol)
265+
266+
267+
@pytest.mark.parametrize("static_length", [False, True])
268+
def test_threeway_mul(static_length):
269+
# Regression test for https://github.com/pymc-devs/pytensor/issues/1184
270+
# x, y, z = vectors("x", "y", "z")
271+
sh = (3,) if static_length else (None,)
272+
x = tensor("x", shape=sh)
273+
y = tensor("y", shape=sh)
274+
z = tensor("z", shape=sh)
275+
out = einsum("..., ..., ... -> ...", x, y, z)
276+
277+
x_test = np.ones((3,), dtype=x.dtype)
278+
y_test = x_test + 1
279+
z_test = x_test + 2
280+
np.testing.assert_allclose(
281+
out.eval({x: x_test, y: y_test, z: z_test}),
282+
np.full((3,), fill_value=6),
283+
)

0 commit comments

Comments
 (0)