Skip to content

Commit b6f6e5c

Browse files
committed
fixed rewrite for solve
1 parent 6c1d9c5 commit b6f6e5c

File tree

2 files changed

+81
-20
lines changed

2 files changed

+81
-20
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
Cholesky,
3737
Solve,
3838
SolveBase,
39+
SolveTriangular,
3940
block_diag,
4041
cholesky,
4142
solve,
@@ -542,28 +543,46 @@ def svd_uv_merge(fgraph, node):
542543

543544

544545
def _find_solve_with_eye(node):
546+
valid_solves = (Solve, SolveTriangular)
545547
# First, we look for the solve op
546-
if not (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)):
547-
return None
548-
548+
if not (
549+
isinstance(node.op, Blockwise) and isinstance(node.op.core_op, valid_solves)
550+
):
551+
return False
549552
# Check whether second input to solve is Eye
550-
solve_inputs = node.owner.inputs
553+
solve_inputs = node.inputs
551554
potential_eye_input = solve_inputs[1].owner
552-
if not (isinstance(potential_eye_input.op, Eye)):
553-
return None
555+
if not (potential_eye_input and isinstance(potential_eye_input.op, Eye)):
556+
return False
554557

555-
return node
558+
return True
556559

557560

558561
@register_canonicalize
559562
@register_stabilize
560563
@node_rewriter([Blockwise])
561564
def rewrite_inv_inv(fgraph, node):
562-
valid_inverses = (MatrixInverse, MatrixPinv)
565+
valid_inverses = (MatrixInverse, MatrixPinv, Solve, SolveTriangular)
566+
valid_solves = (Solve, SolveTriangular)
567+
# Check if Solve has b = eye
568+
solve_inv_check = True
569+
if isinstance(node.op.core_op, valid_solves):
570+
solve_inv_check = _find_solve_with_eye(node)
571+
572+
if not solve_inv_check:
573+
return None
574+
563575
if not (isinstance(node.op.core_op, valid_inverses)):
564576
return None
565577

566578
potential_inner_inv = node.inputs[0].owner
579+
# Check if its an inner solve as well, does that have b = eye
580+
solve_inv_check = True
581+
if isinstance(potential_inner_inv.op.core_op, valid_solves):
582+
solve_inv_check = _find_solve_with_eye(potential_inner_inv)
583+
if not solve_inv_check:
584+
return None
585+
567586
if not (
568587
potential_inner_inv
569588
and isinstance(potential_inner_inv.op, Blockwise)

tests/tensor/rewriting/test_linalg.py

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -547,17 +547,59 @@ def test_svd_uv_merge():
547547
assert svd_counter == 1
548548

549549

550-
@pytest.mark.parametrize(
551-
"inv_op", [pt.linalg.inv, pt.linalg.pinv], ids=["MatrixInverse", "MatrixPinv"]
552-
)
553-
def test_inv_inv_rewrite(inv_op):
550+
def test_inv_inv_rewrite():
554551
x = pt.matrix("a")
555-
ii_x = inv_op(inv_op(x))
556-
f_rewritten = function([x], ii_x, mode="FAST_RUN")
557-
nodes = f_rewritten.maker.fgraph.apply_nodes
558-
552+
inv_ops = [pt.linalg.inv, pt.linalg.pinv]
553+
solve_ops = [pt.linalg.solve, pt.linalg.solve_triangular]
559554
valid_inverses = (MatrixInverse, MatrixPinv)
560-
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
561-
562-
x_testing = np.random.rand(10, 10)
563-
np.testing.assert_allclose(f_rewritten(x_testing), x_testing)
555+
valid_solves = (Solve, SolveTriangular)
556+
all_valid = (MatrixInverse, MatrixPinv, Solve, SolveTriangular)
557+
# inv(inv)
558+
for inv_op in inv_ops:
559+
for inv_op_2 in inv_ops:
560+
ii_x = inv_op(inv_op_2(x))
561+
f_rewritten = function([x], ii_x, mode="FAST_RUN")
562+
nodes = f_rewritten.maker.fgraph.apply_nodes
563+
564+
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
565+
566+
x_testing = np.random.rand(10, 10)
567+
np.testing.assert_allclose(f_rewritten(x_testing), x_testing)
568+
# solve(solve)
569+
b_eye = pt.eye(10)
570+
for solve_op in solve_ops:
571+
for solve_op_2 in solve_ops:
572+
ss_x = solve_op(solve_op_2(x, b_eye), b_eye)
573+
with pytensor.config.change_flags(optimizer_verbose=True):
574+
f_rewritten = function([x], ss_x, mode="FAST_RUN")
575+
nodes = f_rewritten.maker.fgraph.apply_nodes
576+
577+
assert not any(isinstance(node.op, valid_solves) for node in nodes)
578+
579+
x_testing = np.random.rand(10, 10)
580+
np.testing.assert_allclose(f_rewritten(x_testing), x_testing)
581+
582+
# inv(solve)
583+
for inv_op in inv_ops:
584+
for solve_op in solve_ops:
585+
is_x = inv_op(solve_op(x, b_eye))
586+
with pytensor.config.change_flags(optimizer_verbose=True):
587+
f_rewritten = function([x], is_x, mode="FAST_RUN")
588+
nodes = f_rewritten.maker.fgraph.apply_nodes
589+
assert not any(isinstance(node.op, all_valid) for node in nodes)
590+
591+
x_testing = np.random.rand(10, 10)
592+
np.testing.assert_allclose(f_rewritten(x_testing), x_testing)
593+
594+
# solve(inv)
595+
for solve_op in solve_ops:
596+
for inv_op in inv_ops:
597+
si_x = solve_op(inv_op(x), b_eye)
598+
with pytensor.config.change_flags(optimizer_verbose=True):
599+
f_rewritten = function([x], si_x, mode="FAST_RUN")
600+
nodes = f_rewritten.maker.fgraph.apply_nodes
601+
602+
assert not any(isinstance(node.op, all_valid) for node in nodes)
603+
604+
x_testing = np.random.rand(10, 10)
605+
np.testing.assert_allclose(f_rewritten(x_testing), x_testing)

0 commit comments

Comments
 (0)