Skip to content

Commit 7a7e806

Browse files
committed
Added rewrite for inv(inv(x)) -> x
1 parent 05d376f commit 7a7e806

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,3 +539,21 @@ def svd_uv_merge(fgraph, node):
539539
or len(fgraph.clients[cl.outputs[2]]) > 0
540540
):
541541
return [cl.outputs[1]]
542+
543+
544+
@register_canonicalize
545+
@register_stabilize
546+
@node_rewriter([Blockwise])
547+
def rewrite_inv_inv(fgraph, node):
548+
valid_inverses = (MatrixInverse, MatrixPinv)
549+
if not (isinstance(node.op.core_op, valid_inverses)):
550+
return None
551+
552+
potential_inner_inv = node.inputs[0].owner
553+
if not (
554+
isinstance(potential_inner_inv.op, Blockwise)
555+
and isinstance(node.op.core_op, valid_inverses)
556+
):
557+
return None
558+
559+
return [potential_inner_inv.inputs[0]]

tests/tensor/rewriting/test_linalg.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pytensor import function
1010
from pytensor import tensor as pt
1111
from pytensor.compile import get_default_mode
12+
from pytensor.compile.ops import DeepCopyOp
1213
from pytensor.configdefaults import config
1314
from pytensor.tensor import swapaxes
1415
from pytensor.tensor.blockwise import Blockwise
@@ -545,3 +546,15 @@ def test_svd_uv_merge():
545546
assert node.op.compute_uv
546547
svd_counter += 1
547548
assert svd_counter == 1
549+
550+
551+
def test_inv_inv_rewrite():
552+
x = pt.matrix("a")
553+
ii_x = pt.linalg.inv(pt.linalg.inv(x))
554+
f_rewritten = function([x], ii_x, mode="FAST_RUN")
555+
nodes = f_rewritten.maker.fgraph.apply_nodes
556+
557+
assert all(isinstance(node.op, DeepCopyOp) for node in nodes)
558+
559+
x_testing = np.random.rand(10, 10)
560+
np.testing.assert_allclose(f_rewritten(x_testing), x_testing)

0 commit comments

Comments
 (0)