|
9 | 9 | from pytensor import function
|
10 | 10 | from pytensor import tensor as pt
|
11 | 11 | from pytensor.compile import get_default_mode
|
12 |
| -from pytensor.compile.ops import DeepCopyOp |
13 | 12 | from pytensor.configdefaults import config
|
14 | 13 | from pytensor.tensor import swapaxes
|
15 | 14 | from pytensor.tensor.blockwise import Blockwise
|
@@ -548,13 +547,17 @@ def test_svd_uv_merge():
|
548 | 547 | assert svd_counter == 1
|
549 | 548 |
|
550 | 549 |
|
551 |
| -def test_inv_inv_rewrite(): |
| 550 | +@pytest.mark.parametrize( |
| 551 | + "inv_op", [pt.linalg.inv, pt.linalg.pinv], ids=["MatrixInverse", "MatrixPinv"] |
| 552 | +) |
| 553 | +def test_inv_inv_rewrite(inv_op): |
552 | 554 | x = pt.matrix("a")
|
553 |
| - ii_x = pt.linalg.inv(pt.linalg.inv(x)) |
| 555 | + ii_x = inv_op(inv_op(x)) |
554 | 556 | f_rewritten = function([x], ii_x, mode="FAST_RUN")
|
555 | 557 | nodes = f_rewritten.maker.fgraph.apply_nodes
|
556 | 558 |
|
557 |
| - assert all(isinstance(node.op, DeepCopyOp) for node in nodes) |
| 559 | + valid_inverses = (MatrixInverse, MatrixPinv) |
| 560 | + assert not any(isinstance(node.op, valid_inverses) for node in nodes) |
558 | 561 |
|
559 | 562 | x_testing = np.random.rand(10, 10)
|
560 | 563 | np.testing.assert_allclose(f_rewritten(x_testing), x_testing)
|
0 commit comments