|
8 | 8 | from pytensor import Mode, config, function |
9 | 9 | from pytensor.graph import FunctionGraph |
10 | 10 | from pytensor.graph.op import HasInnerGraph |
| 11 | +from pytensor.tensor import matrix |
11 | 12 | from pytensor.tensor.basic import moveaxis |
12 | 13 | from pytensor.tensor.blockwise import Blockwise |
13 | 14 | from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum |
@@ -281,3 +282,15 @@ def test_threeway_mul(static_length): |
281 | 282 | out.eval({x: x_test, y: y_test, z: z_test}), |
282 | 283 | np.full((3,), fill_value=6), |
283 | 284 | ) |
| 285 | + |
| 286 | + |
| 287 | +def test_repeated_inputs(): |
| 288 | + x = matrix("x") |
| 289 | + out_repeated = einsum("ij,ij->i", x, x) |
| 290 | + out_copy = einsum("ij,ij->i", x, x.copy()) |
| 291 | + |
| 292 | + x_test = np.array([[1, 2], [3, 4]]) |
| 293 | + |
| 294 | + np.testing.assert_allclose( |
| 295 | + out_repeated.eval({x: x_test}), out_copy.eval({x: x_test}) |
| 296 | + ) |
0 commit comments