Skip to content

Commit 039ccfe

Browse files
committed
add triangular helper and det triangular rewrite
1 parent 231a977 commit 039ccfe

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,3 +1013,28 @@ def slogdet_specialization(fgraph, node):
10131013
k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
10141014
}
10151015
return replacements
1016+
1017+
1018+
def _find_triangular_from_cholesky(potential_triangular):
1019+
if not (
1020+
potential_triangular.owner is not None
1021+
and isinstance(potential_triangular.owner.op, Blockwise)
1022+
and isinstance(potential_triangular.owner.op.core_op, Cholesky)
1023+
):
1024+
return None
1025+
1026+
return potential_triangular
1027+
1028+
1029+
@register_canonicalize
1030+
@register_stabilize
1031+
@node_rewriter([det])
1032+
def det_triangular_to_prod_diag(fgraph, node):
1033+
inputs = node.inputs[0]
1034+
triangular_check = _find_triangular_from_cholesky(inputs)
1035+
1036+
if triangular_check:
1037+
det_val = inputs.diagonal().prod()
1038+
return [det_val]
1039+
1040+
return None

tests/tensor/rewriting/test_linalg.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,3 +996,36 @@ def test_slogdet_specialization():
996996
f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
997997
nodes = f.maker.fgraph.apply_nodes
998998
assert not any(isinstance(node.op, SLogDet) for node in nodes)
999+
1000+
1001+
def test_det_triangular():
1002+
x = pt.matrix("x")
1003+
x_triangular = pt.linalg.cholesky(x)
1004+
z = pt.linalg.det(x_triangular)
1005+
1006+
# Rewrite Test
1007+
f_rewritten = function([x], z, mode="FAST_RUN")
1008+
1009+
nodes = f_rewritten.maker.fgraph.apply_nodes
1010+
assert not any(isinstance(node.op, Det) for node in nodes)
1011+
1012+
# Numeric test
1013+
x_test = np.random.rand(10, 10).astype(config.floatX)
1014+
x_psd = np.dot(x_test, x_test.T)
1015+
x_triangular = np.linalg.cholesky(x_psd)
1016+
det_val = np.linalg.det(x_triangular)
1017+
rewritten_val = f_rewritten(x_psd)
1018+
assert_allclose(
1019+
det_val,
1020+
rewritten_val,
1021+
atol=1e-3 if config.floatX == "float32" else 1e-8,
1022+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
1023+
)
1024+
1025+
# Case where rewrite should not be applied
1026+
y = pt.matrix("y")
1027+
z = pt.linalg.det(y)
1028+
1029+
f_rewritten = function([y], z, mode="FAST_RUN")
1030+
nodes = f_rewritten.maker.fgraph.apply_nodes
1031+
assert any(isinstance(node.op, Det) for node in nodes)

0 commit comments

Comments
 (0)