Skip to content

Commit 061afbb

Browse files
committed
Added regression test for repeated inputs to the einsum function
1 parent f0f07c2 commit 061afbb

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

tests/tensor/test_einsum.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pytensor import Mode, config, function
99
from pytensor.graph import FunctionGraph
1010
from pytensor.graph.op import HasInnerGraph
11+
from pytensor.tensor import matrix
1112
from pytensor.tensor.basic import moveaxis
1213
from pytensor.tensor.blockwise import Blockwise
1314
from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
@@ -281,3 +282,15 @@ def test_threeway_mul(static_length):
281282
out.eval({x: x_test, y: y_test, z: z_test}),
282283
np.full((3,), fill_value=6),
283284
)
285+
286+
287+
def test_repeated_inputs():
288+
x = matrix("x")
289+
out_repeated = einsum("ij,ij->i", x, x)
290+
out_copy = einsum("ij,ij->i", x, x.copy())
291+
292+
x_test = np.array([[1, 2], [3, 4]])
293+
294+
np.testing.assert_allclose(
295+
out_repeated.eval({x: x_test}), out_copy.eval({x: x_test})
296+
)

0 commit comments

Comments
 (0)