Skip to content

Commit bdd1679

Browse files
committed
added rewrite for eig when input matrix is identity
1 parent fc29a91 commit bdd1679

File tree

2 files changed

+63
-3
lines changed

2 files changed

+63
-3
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1016,12 +1016,46 @@ def slogdet_specialization(fgraph, node):
10161016
return replacements
10171017

10181018

1019+
@register_canonicalize
1020+
@register_stabilize
1021+
@node_rewriter([eig])
1022+
def rewrite_eig_eye(fgraph, node):
1023+
"""
1024+
This rewrite takes advantage of the fact that for any identity matrix, all the eigenvalues are 1 and the eigenvectors are the standard basis.
1025+
1026+
Parameters
1027+
----------
1028+
fgraph: FunctionGraph
1029+
Function graph being optimized
1030+
node: Apply
1031+
Node of the function graph to be optimized
1032+
1033+
Returns
1034+
-------
1035+
list of Variable, optional
1036+
List of optimized variables, or None if no optimization was performed
1037+
"""
1038+
# Check whether input to Eig is Eye and the 1's are on main diagonal
1039+
potential_eye = node.inputs[0]
1040+
if not (
1041+
potential_eye.owner
1042+
and isinstance(potential_eye.owner.op, Eye)
1043+
and getattr(potential_eye.owner.inputs[-1], "data", -1).item() == 0
1044+
):
1045+
return None
1046+
1047+
eigval_rewritten = pt.ones(potential_eye.shape[-1])
1048+
eigvec_rewritten = pt.eye(potential_eye.shape[-1])
1049+
1050+
return [eigval_rewritten, eigvec_rewritten]
1051+
1052+
10191053
@register_canonicalize
10201054
@register_stabilize
10211055
@node_rewriter([eig])
10221056
def rewrite_eig_diag(fgraph, node):
10231057
"""
1024-
This rewrite takes advantage of the fact that for a diagonal matrix, the eigenvalues are simply the diagonal elements and the eigenvectors are the identity matrix.
1058+
This rewrite takes advantage of the fact that for a diagonal matrix, the eigenvalues are simply the diagonal elements and the eigenvectors are the standard basis.
10251059
10261060
The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices
10271061
that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to

tests/tensor/rewriting/test_linalg.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,14 +1047,40 @@ def test_eig_diag_from_eye_mul(shape):
10471047
)
10481048

10491049

1050-
def test_eig_diag_from_diag():
1050+
def test_eig_eye():
1051+
x = pt.eye(10)
1052+
eigval, eigvec = pt.linalg.eig(x)
1053+
1054+
# REWRITE TEST
1055+
f_rewritten = function([], [eigval, eigvec], mode="FAST_RUN")
1056+
nodes = f_rewritten.maker.fgraph.apply_nodes
1057+
assert not any(isinstance(node.op, Eig) for node in nodes)
1058+
1059+
# NUMERIC VALUE TEST
1060+
x_test = np.eye(10)
1061+
eigval, eigvec = np.linalg.eig(x_test)
1062+
rewritten_eigval, rewritten_eigvec = f_rewritten()
1063+
assert_allclose(
1064+
eigval,
1065+
rewritten_eigval,
1066+
atol=1e-3 if config.floatX == "float32" else 1e-8,
1067+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
1068+
)
1069+
assert_allclose(
1070+
eigvec,
1071+
rewritten_eigvec,
1072+
atol=1e-3 if config.floatX == "float32" else 1e-8,
1073+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
1074+
)
1075+
1076+
1077+
def test_eig_diag():
10511078
x = pt.tensor("x", shape=(None,))
10521079
x_diag = pt.diag(x)
10531080
eigval, eigvec = pt.linalg.eig(x_diag)
10541081

10551082
# REWRITE TEST
10561083
f_rewritten = function([x], [eigval, eigvec], mode="FAST_RUN")
1057-
f_rewritten.dprint()
10581084
nodes = f_rewritten.maker.fgraph.apply_nodes
10591085
assert not any(isinstance(node.op, Eig) for node in nodes)
10601086

0 commit comments

Comments
 (0)