Skip to content

Commit fc29a91

Browse files
committed
Added eig rewrite for diagonal matrix
1 parent 231a977 commit fc29a91

File tree

2 files changed

+147
-0
lines changed

2 files changed

+147
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
MatrixPinv,
3535
SLogDet,
3636
det,
37+
eig,
3738
inv,
3839
kron,
3940
pinv,
@@ -1013,3 +1014,70 @@ def slogdet_specialization(fgraph, node):
10131014
k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
10141015
}
10151016
return replacements
1017+
1018+
1019+
@register_canonicalize
1020+
@register_stabilize
1021+
@node_rewriter([eig])
1022+
def rewrite_eig_diag(fgraph, node):
1023+
"""
1024+
This rewrite takes advantage of the fact that for a diagonal matrix, the eigenvalues are simply the diagonal elements and the eigenvectors are the identity matrix.
1025+
1026+
The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices
1027+
that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to
1028+
make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar,
1029+
vector or a matrix.
1030+
1031+
Parameters
1032+
----------
1033+
fgraph: FunctionGraph
1034+
Function graph being optimized
1035+
node: Apply
1036+
Node of the function graph to be optimized
1037+
1038+
Returns
1039+
-------
1040+
list of Variable, optional
1041+
List of optimized variables, or None if no optimization was performed
1042+
"""
1043+
inputs = node.inputs[0]
1044+
1045+
# Check for use of pt.diag first
1046+
if (
1047+
inputs.owner
1048+
and isinstance(inputs.owner.op, AllocDiag)
1049+
and AllocDiag.is_offset_zero(inputs.owner)
1050+
):
1051+
eigval_rewritten = pt.diag(inputs)
1052+
eigvec_rewritten = pt.eye(inputs.shape[-1])
1053+
return [eigval_rewritten, eigvec_rewritten]
1054+
1055+
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
1056+
inputs_or_none = _find_diag_from_eye_mul(inputs)
1057+
if inputs_or_none is None:
1058+
return None
1059+
1060+
eye_input, non_eye_inputs = inputs_or_none
1061+
1062+
# Dealing with only one other input
1063+
if len(non_eye_inputs) != 1:
1064+
return None
1065+
1066+
eye_input, non_eye_input = eye_input, non_eye_inputs[0]
1067+
# eigval_rewritten = pt.diag(non_eye_input)
1068+
eigvec_rewritten = eye_input
1069+
1070+
# Checking if original x was scalar/vector/matrix
1071+
if non_eye_input.type.broadcastable[-2:] == (True, True):
1072+
# For scalar
1073+
eigval_rewritten = pt.full(
1074+
(eye_input.shape[0],), non_eye_input.squeeze(axis=(-1, -2))
1075+
)
1076+
elif non_eye_input.type.broadcastable[-2:] == (False, False):
1077+
# For Matrix
1078+
eigval_rewritten = pt.diag(non_eye_input)
1079+
else:
1080+
# For vector
1081+
eigval_rewritten = non_eye_input.squeeze()
1082+
1083+
return [eigval_rewritten, eigvec_rewritten]

tests/tensor/rewriting/test_linalg.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pytensor.tensor.nlinalg import (
1919
SVD,
2020
Det,
21+
Eig,
2122
KroneckerProduct,
2223
MatrixInverse,
2324
MatrixPinv,
@@ -996,3 +997,81 @@ def test_slogdet_specialization():
996997
f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
997998
nodes = f.maker.fgraph.apply_nodes
998999
assert not any(isinstance(node.op, SLogDet) for node in nodes)
1000+
1001+
1002+
@pytest.mark.parametrize(
1003+
"shape",
1004+
[(), (7,), (1, 7), (7, 1), (7, 7)],
1005+
ids=["scalar", "vector", "row_vec", "col_vec", "matrix"],
1006+
)
1007+
def test_eig_diag_from_eye_mul(shape):
1008+
# Initializing x based on scalar/vector/matrix
1009+
x = pt.tensor("x", shape=shape)
1010+
y = pt.eye(7) * x
1011+
1012+
# Calculating eigval and eigvec using pt.linalg.eig
1013+
eigval, eigvec = pt.linalg.eig(y)
1014+
1015+
# REWRITE TEST
1016+
f_rewritten = function([x], [eigval, eigvec], mode="FAST_RUN")
1017+
nodes = f_rewritten.maker.fgraph.apply_nodes
1018+
1019+
assert not any(
1020+
isinstance(node.op, Eig) or isinstance(getattr(node.op, "core_op", None), Eig)
1021+
for node in nodes
1022+
)
1023+
1024+
# NUMERIC VALUE TEST
1025+
if len(shape) == 0:
1026+
x_test = np.array(np.random.rand()).astype(config.floatX)
1027+
elif len(shape) == 1:
1028+
x_test = np.random.rand(*shape).astype(config.floatX)
1029+
else:
1030+
x_test = np.random.rand(*shape).astype(config.floatX)
1031+
1032+
x_test_matrix = np.eye(7) * x_test
1033+
eigval, eigvec = np.linalg.eig(x_test_matrix)
1034+
rewritten_eigval, rewritten_eigvec = f_rewritten(x_test)
1035+
1036+
assert_allclose(
1037+
eigval,
1038+
rewritten_eigval,
1039+
atol=1e-3 if config.floatX == "float32" else 1e-8,
1040+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
1041+
)
1042+
assert_allclose(
1043+
eigvec,
1044+
rewritten_eigvec,
1045+
atol=1e-3 if config.floatX == "float32" else 1e-8,
1046+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
1047+
)
1048+
1049+
1050+
def test_eig_diag_from_diag():
1051+
x = pt.tensor("x", shape=(None,))
1052+
x_diag = pt.diag(x)
1053+
eigval, eigvec = pt.linalg.eig(x_diag)
1054+
1055+
# REWRITE TEST
1056+
f_rewritten = function([x], [eigval, eigvec], mode="FAST_RUN")
1057+
f_rewritten.dprint()
1058+
nodes = f_rewritten.maker.fgraph.apply_nodes
1059+
assert not any(isinstance(node.op, Eig) for node in nodes)
1060+
1061+
# NUMERIC VALUE TEST
1062+
x_test = np.random.rand(7).astype(config.floatX)
1063+
x_test_matrix = np.eye(7) * x_test
1064+
eigval, eigvec = np.linalg.eig(x_test_matrix)
1065+
rewritten_eigval, rewritten_eigvec = f_rewritten(x_test)
1066+
assert_allclose(
1067+
eigval,
1068+
rewritten_eigval,
1069+
atol=1e-3 if config.floatX == "float32" else 1e-8,
1070+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
1071+
)
1072+
assert_allclose(
1073+
eigvec,
1074+
rewritten_eigvec,
1075+
atol=1e-3 if config.floatX == "float32" else 1e-8,
1076+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
1077+
)

0 commit comments

Comments
 (0)