Skip to content

Commit d5b401f

Browse files
committed
added helper for solve with eye; paramterized tests
1 parent 7a7e806 commit d5b401f

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tests/tensor/rewriting/test_linalg.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from pytensor import function
1010
from pytensor import tensor as pt
1111
from pytensor.compile import get_default_mode
12-
from pytensor.compile.ops import DeepCopyOp
1312
from pytensor.configdefaults import config
1413
from pytensor.tensor import swapaxes
1514
from pytensor.tensor.blockwise import Blockwise
@@ -548,13 +547,17 @@ def test_svd_uv_merge():
548547
assert svd_counter == 1
549548

550549

551-
def test_inv_inv_rewrite():
550+
@pytest.mark.parametrize(
551+
"inv_op", [pt.linalg.inv, pt.linalg.pinv], ids=["MatrixInverse", "MatrixPinv"]
552+
)
553+
def test_inv_inv_rewrite(inv_op):
552554
x = pt.matrix("a")
553-
ii_x = pt.linalg.inv(pt.linalg.inv(x))
555+
ii_x = inv_op(inv_op(x))
554556
f_rewritten = function([x], ii_x, mode="FAST_RUN")
555557
nodes = f_rewritten.maker.fgraph.apply_nodes
556558

557-
assert all(isinstance(node.op, DeepCopyOp) for node in nodes)
559+
valid_inverses = (MatrixInverse, MatrixPinv)
560+
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
558561

559562
x_testing = np.random.rand(10, 10)
560563
np.testing.assert_allclose(f_rewritten(x_testing), x_testing)

0 commit comments

Comments
 (0)