Skip to content

Commit 0c43019

Browse files
committed
added rewrite for inverse of triangular matrix
1 parent 039ccfe commit 0c43019

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,3 +1038,22 @@ def det_triangular_to_prod_diag(fgraph, node):
10381038
return [det_val]
10391039

10401040
return None
1041+
1042+
1043+
@register_canonicalize
1044+
@register_stabilize
1045+
@node_rewriter([Blockwise])
1046+
def rewrite_inv_triangular_to_solve_triangular(fgraph, node):
1047+
core_op = node.op.core_op
1048+
if not (isinstance(core_op, ALL_INVERSE_OPS)):
1049+
return None
1050+
1051+
inputs = node.inputs[0]
1052+
triangular_check = _find_triangular_from_cholesky(inputs)
1053+
1054+
if triangular_check:
1055+
valid_eye = pt.eye(inputs.shape[-1])
1056+
inv_val = solve_triangular(inputs, valid_eye, lower=True)
1057+
return [inv_val]
1058+
1059+
return None

tests/tensor/rewriting/test_linalg.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1009,7 +1009,7 @@ def test_det_triangular():
10091009
nodes = f_rewritten.maker.fgraph.apply_nodes
10101010
assert not any(isinstance(node.op, Det) for node in nodes)
10111011

1012-
# Numeric test
1012+
# Numeric Test
10131013
x_test = np.random.rand(10, 10).astype(config.floatX)
10141014
x_psd = np.dot(x_test, x_test.T)
10151015
x_triangular = np.linalg.cholesky(x_psd)
@@ -1029,3 +1029,31 @@ def test_det_triangular():
10291029
f_rewritten = function([y], z, mode="FAST_RUN")
10301030
nodes = f_rewritten.maker.fgraph.apply_nodes
10311031
assert any(isinstance(node.op, Det) for node in nodes)
1032+
1033+
1034+
@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
1035+
def test_inv_triangular(inv_op):
1036+
x = pt.matrix("x")
1037+
x_triangular = pt.linalg.cholesky(x)
1038+
z = get_pt_function(x_triangular, inv_op)
1039+
1040+
# Rewrite Test
1041+
f_rewritten = function([x], z, mode="FAST_RUN")
1042+
nodes = f_rewritten.maker.fgraph.apply_nodes
1043+
1044+
valid_inverses = (MatrixInverse, MatrixPinv)
1045+
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
1046+
1047+
# Numeric Test
1048+
x_test = np.random.rand(10, 10).astype(config.floatX)
1049+
x_psd = np.dot(x_test, x_test.T)
1050+
x_triangular = np.linalg.cholesky(x_psd)
1051+
inv_val = np.linalg.inv(x_triangular)
1052+
rewritten_val = f_rewritten(x_psd)
1053+
1054+
assert_allclose(
1055+
inv_val,
1056+
rewritten_val,
1057+
atol=1e-3 if config.floatX == "float32" else 1e-8,
1058+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
1059+
)

0 commit comments

Comments
 (0)