File tree Expand file tree Collapse file tree 2 files changed +31
-0
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 2 files changed +31
-0
lines changed Original file line number Diff line number Diff line change @@ -539,3 +539,21 @@ def svd_uv_merge(fgraph, node):
539
539
or len (fgraph .clients [cl .outputs [2 ]]) > 0
540
540
):
541
541
return [cl .outputs [1 ]]
542
+
543
+
544
+ @register_canonicalize
545
+ @register_stabilize
546
+ @node_rewriter ([Blockwise ])
547
+ def rewrite_inv_inv (fgraph , node ):
548
+ valid_inverses = (MatrixInverse , MatrixPinv )
549
+ if not (isinstance (node .op .core_op , valid_inverses )):
550
+ return None
551
+
552
+ potential_inner_inv = node .inputs [0 ].owner
553
+ if not (
554
+ isinstance (potential_inner_inv .op , Blockwise )
555
+ and isinstance (node .op .core_op , valid_inverses )
556
+ ):
557
+ return None
558
+
559
+ return [potential_inner_inv .inputs [0 ]]
Original file line number Diff line number Diff line change 9
9
from pytensor import function
10
10
from pytensor import tensor as pt
11
11
from pytensor .compile import get_default_mode
12
+ from pytensor .compile .ops import DeepCopyOp
12
13
from pytensor .configdefaults import config
13
14
from pytensor .tensor import swapaxes
14
15
from pytensor .tensor .blockwise import Blockwise
@@ -545,3 +546,15 @@ def test_svd_uv_merge():
545
546
assert node .op .compute_uv
546
547
svd_counter += 1
547
548
assert svd_counter == 1
549
+
550
+
551
+ def test_inv_inv_rewrite ():
552
+ x = pt .matrix ("a" )
553
+ ii_x = pt .linalg .inv (pt .linalg .inv (x ))
554
+ f_rewritten = function ([x ], ii_x , mode = "FAST_RUN" )
555
+ nodes = f_rewritten .maker .fgraph .apply_nodes
556
+
557
+ assert all (isinstance (node .op , DeepCopyOp ) for node in nodes )
558
+
559
+ x_testing = np .random .rand (10 , 10 )
560
+ np .testing .assert_allclose (f_rewritten (x_testing ), x_testing )
You can’t perform that action at this time.
0 commit comments