Skip to content

Commit 1e42e10

Browse files
committed
added docstrings for rewrite and helper
1 parent fbdf031 commit 1e42e10

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -542,17 +542,22 @@ def svd_uv_merge(fgraph, node):
542542
return [cl.outputs[1]]
543543

544544

545-
def _find_solve_with_eye(node):
545+
def _find_solve_with_eye(node) -> bool:
546+
"""
547+
The result of solve(A, b) is the solution x to the linear equation Ax = b. If b is an identity matrix (Eye), x is simply inv(A).
548+
Here, we are just recognising whether the solve operation returns an inverse or not; not replacing it because solve is mathematically more stable than inv.
549+
"""
546550
valid_solves = (Solve, SolveTriangular)
547-
# First, we look for the solve op
551+
# First, we verify whether we have a valid solve op
548552
if not (
549553
isinstance(node.op, Blockwise) and isinstance(node.op.core_op, valid_solves)
550554
):
551555
return False
552-
# Check whether second input to solve is Eye
556+
# If the current op is solve, we check for b. If b is an identity matrix (Eye), we can return True
553557
solve_inputs = node.inputs
554558
eye_input = solve_inputs[1].owner
555559

560+
# We check for b = Eye and also make sure that if it was an Eye, then k = 0 (1's are only across the main diagonal)
556561
if not (eye_input and isinstance(eye_input.op, Eye)):
557562
return False
558563

@@ -565,9 +570,29 @@ def _find_solve_with_eye(node):
565570
@register_stabilize
566571
@node_rewriter([Blockwise])
567572
def rewrite_inv_inv(fgraph, node):
573+
"""
574+
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once.
575+
576+
Here, we check for direct inverse operations (inv/pinv) and also solve operations (solve/solve_triangular) in the case when b = Eye. This allows any combination of these "inverse" nodes to be simply rewritten.
577+
578+
Parameters
579+
----------
580+
fgraph: FunctionGraph
581+
Function graph being optimized
582+
node: Apply
583+
Node of the function graph to be optimized
584+
585+
Returns
586+
-------
587+
list of Variable, optional
588+
List of optimized variables, or None if no optimization was performed
589+
"""
568590
valid_inverses = (MatrixInverse, MatrixPinv, Solve, SolveTriangular)
569591
valid_solves = (Solve, SolveTriangular)
570592
# Check if its a valid inverse operation (either inv/pinv or if its solve, then b = eye)
593+
# In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
594+
# If the outer operation is a solve op with b = Eye, it treats it as inverse and finds the inner operation
595+
# If the outer operation is not an inverse (neither inv nor solve with eye), we do not apply this rewrite
571596
inv_check = False
572597
if isinstance(node.op, Blockwise) and isinstance(node.op.core_op, valid_inverses):
573598
inv_check = True
@@ -581,16 +606,17 @@ def rewrite_inv_inv(fgraph, node):
581606
if potential_inner_inv is None or potential_inner_inv.op is None:
582607
return None
583608

584-
# Check if its a valid inverse operation (either inv/pinv or if its solve, then b = eye)
585-
inv_check = False
609+
# Similar to the check for outer operation, we now run the same checks for the inner op.
610+
# If its an inverse or solve with eye, we apply the rewrite. Otherwise, we return None.
611+
inv_check_inner = False
586612
if isinstance(potential_inner_inv.op, Blockwise) and isinstance(
587613
potential_inner_inv.op.core_op, valid_inverses
588614
):
589-
inv_check = True
615+
inv_check_inner = True
590616
if isinstance(potential_inner_inv.op.core_op, valid_solves):
591-
inv_check = _find_solve_with_eye(potential_inner_inv)
617+
inv_check_inner = _find_solve_with_eye(potential_inner_inv)
592618

593-
if not inv_check:
619+
if not inv_check_inner:
594620
return None
595621

596622
if not (

tests/tensor/rewriting/test_linalg.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,7 @@ def test_transinv_to_invtrans():
9494
X = matrix("X")
9595
Y = matrix_inverse(X)
9696
Z = Y.transpose()
97-
print(Z.dprint())
9897
f = pytensor.function([X], Z)
99-
print(f.dprint())
10098
if config.mode != "FAST_COMPILE":
10199
for node in f.maker.fgraph.toposort():
102100
if isinstance(node.op, MatrixInverse):

0 commit comments

Comments
 (0)