
Commit ab13fe0

Add solve_discrete_lyapunov and solve_continuous_lyapunov (#33)
Co-authored-by: jessegrabowski <[email protected]>
1 parent f203dd7 · commit ab13fe0

2 files changed: +232 −8 lines

pytensor/tensor/slinalg.py

Lines changed: 162 additions & 2 deletions
@@ -1,20 +1,27 @@
 import logging
 import warnings
-from typing import Union
+from typing import TYPE_CHECKING, Union
 
 import numpy as np
 import scipy.linalg
+from typing_extensions import Literal
 
-import pytensor.tensor
+import pytensor
+import pytensor.tensor as pt
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor import basic as at
 from pytensor.tensor import math as atm
+from pytensor.tensor.shape import reshape
 from pytensor.tensor.type import matrix, tensor, vector
 from pytensor.tensor.var import TensorVariable
 
 
+if TYPE_CHECKING:
+    from pytensor.tensor import TensorLike
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -735,6 +742,159 @@ def perform(self, node, inputs, outputs):
 
 expm = Expm()
 
+
+class SolveContinuousLyapunov(Op):
+    __props__ = ()
+
+    def make_node(self, A, B):
+        A = as_tensor_variable(A)
+        B = as_tensor_variable(B)
+
+        out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
+        X = pytensor.tensor.matrix(dtype=out_dtype)
+
+        return pytensor.graph.basic.Apply(self, [A, B], [X])
+
+    def perform(self, node, inputs, output_storage):
+        (A, B) = inputs
+        X = output_storage[0]
+
+        X[0] = scipy.linalg.solve_continuous_lyapunov(A, B)
+
+    def infer_shape(self, fgraph, node, shapes):
+        return [shapes[0]]
+
+    def grad(self, inputs, output_grads):
+        # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
+        # Note that they write the equation as AX + XA.H + Q = 0, while scipy uses AX + XA^H = Q,
+        # so minor adjustments need to be made.
+        A, Q = inputs
+        (dX,) = output_grads
+
+        X = self(A, Q)
+        S = self(A.conj().T, -dX)  # Eq 31, adjusted
+
+        A_bar = S.dot(X.conj().T) + S.conj().T.dot(X)
+        Q_bar = -S  # Eq 29, adjusted
+
+        return [A_bar, Q_bar]
+
+
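The adjoint identities in SolveContinuousLyapunov.grad can be sanity-checked outside PyTensor. Below is a minimal editorial sketch (not part of this commit) that evaluates the Eq. 29/31 expressions with SciPy directly and compares one entry of A_bar against a central finite difference, for the real case:

import numpy as np
from scipy.linalg import solve_continuous_lyapunov

rng = np.random.default_rng(0)
N = 4
A = rng.normal(size=(N, N))
Q = rng.normal(size=(N, N))
dX = rng.normal(size=(N, N))  # upstream gradient dL/dX

X = solve_continuous_lyapunov(A, Q)
S = solve_continuous_lyapunov(A.T, -dX)  # Eq 31, adjusted; A.conj().T == A.T for real A
A_bar = S @ X.T + S.T @ X                # dL/dA
Q_bar = -S                               # Eq 29, adjusted: dL/dQ (shown for completeness)

# Central finite difference for a single entry of A
eps = 1e-6
E = np.zeros((N, N))
E[1, 2] = 1.0
fd = np.sum(dX * (solve_continuous_lyapunov(A + eps * E, Q)
                  - solve_continuous_lyapunov(A - eps * E, Q))) / (2 * eps)
assert np.isclose(A_bar[1, 2], fd, rtol=1e-4)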
+class BilinearSolveDiscreteLyapunov(Op):
+    def make_node(self, A, B):
+        A = as_tensor_variable(A)
+        B = as_tensor_variable(B)
+
+        out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
+        X = pytensor.tensor.matrix(dtype=out_dtype)
+
+        return pytensor.graph.basic.Apply(self, [A, B], [X])
+
+    def perform(self, node, inputs, output_storage):
+        (A, B) = inputs
+        X = output_storage[0]
+
+        X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear")
+
+    def infer_shape(self, fgraph, node, shapes):
+        return [shapes[0]]
+
+    def grad(self, inputs, output_grads):
+        # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
+        A, Q = inputs
+        (dX,) = output_grads
+
+        X = self(A, Q)
+
+        # Eq 41, note that it is not written as a proper Lyapunov equation
+        S = self(A.conj().T, dX)
+
+        A_bar = pytensor.tensor.linalg.matrix_dot(
+            S, A, X.conj().T
+        ) + pytensor.tensor.linalg.matrix_dot(S.conj().T, A, X)
+        Q_bar = S
+        return [A_bar, Q_bar]
+
+
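The Eq. 41 adjoint of the bilinear discrete solver admits the same kind of check: because S solves S - A^H S A = dX, it is itself a discrete Lyapunov solution with A replaced by A^H. An editorial sketch (not part of this commit), again for the real case:

import numpy as np
from scipy.linalg import solve_discrete_lyapunov

rng = np.random.default_rng(0)
N = 4
A = rng.normal(size=(N, N)) * 0.3  # keep the spectral radius below 1
Q = rng.normal(size=(N, N))
dX = rng.normal(size=(N, N))  # upstream gradient dL/dX

X = solve_discrete_lyapunov(A, Q)
S = solve_discrete_lyapunov(A.T, dX)  # Eq 41; A.conj().T == A.T for real A
A_bar = S @ A @ X.T + S.T @ A @ X     # dL/dA
Q_bar = S                             # dL/dQ (shown for completeness)

# Central finite difference for a single entry of A
eps = 1e-6
E = np.zeros((N, N))
E[0, 3] = 1.0
fd = np.sum(dX * (solve_discrete_lyapunov(A + eps * E, Q)
                  - solve_discrete_lyapunov(A - eps * E, Q))) / (2 * eps)
assert np.isclose(A_bar[0, 3], fd, rtol=1e-4)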
+_solve_continuous_lyapunov = SolveContinuousLyapunov()
+_solve_bilinear_direct_lyapunov = BilinearSolveDiscreteLyapunov()
+
+
+def iscomplexobj(x):
+    type_ = x.type
+    dtype = type_.dtype
+    return "complex" in dtype
+
+
+def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
+    A_ = as_tensor_variable(A)
+    Q_ = as_tensor_variable(Q)
+
+    if "complex" in A_.type.dtype:
+        AA = kron(A_, A_.conj())
+    else:
+        AA = kron(A_, A_)
+
+    X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel())
+    return reshape(X, Q_.shape)
+
+
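The direct solver just above leans on the row-major vec identity vec(A X A^H) = (A ⊗ conj(A)) vec(X) (with C-order ravel), which turns the discrete Lyapunov equation into the dense linear system (I - A ⊗ conj(A)) vec(X) = vec(Q). A short NumPy sketch of the reduction (editorial, not part of this commit):

import numpy as np

rng = np.random.default_rng(0)
N = 4
A = rng.normal(size=(N, N)) * 0.3  # keep I - kron(A, conj(A)) comfortably nonsingular
Q = rng.normal(size=(N, N))

# Row-major vec trick: vec(A X A^H) = (A ⊗ conj(A)) vec(X)
AA = np.kron(A, A.conj())
X = np.linalg.solve(np.eye(N * N) - AA, Q.ravel()).reshape(N, N)

# X satisfies the discrete Lyapunov equation A X A^H - X + Q = 0
assert np.allclose(A @ X @ A.conj().T - X + Q, 0.0)

Since the system is N^2 x N^2, the dense solve costs O(N^6), which is why the direct method scales poorly with N.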
+def solve_discrete_lyapunov(
+    A: "TensorLike", Q: "TensorLike", method: Literal["direct", "bilinear"] = "direct"
+) -> TensorVariable:
+    """Solve the discrete Lyapunov equation :math:`A X A^H - X + Q = 0`.
+
+    Parameters
+    ----------
+    A
+        Square matrix of shape ``N x N``; must have the same shape as `Q`.
+    Q
+        Square matrix of shape ``N x N``; must have the same shape as `A`.
+    method
+        Solver method used, one of ``"direct"`` or ``"bilinear"``. ``"direct"``
+        solves the problem directly via matrix inversion. This has a pure
+        PyTensor implementation and can thus be cross-compiled to supported
+        backends, and should be preferred when ``N`` is not large. The direct
+        method scales poorly with ``N``, and the bilinear method can be used
+        for larger problems.
+
+    Returns
+    -------
+    Square matrix of shape ``N x N``, representing the solution to the
+    Lyapunov equation.
+
+    """
+    if method not in ["direct", "bilinear"]:
+        raise ValueError(
+            f'Parameter "method" must be one of "direct" or "bilinear", found {method}'
+        )
+
+    if method == "direct":
+        return _direct_solve_discrete_lyapunov(A, Q)
+    if method == "bilinear":
+        return _solve_bilinear_direct_lyapunov(A, Q)
+
+
+def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
+    """Solve the continuous Lyapunov equation :math:`A X + X A^H = Q`.
+
+    Parameters
+    ----------
+    A
+        Square matrix of shape ``N x N``; must have the same shape as `Q`.
+    Q
+        Square matrix of shape ``N x N``; must have the same shape as `A`.
+
+    Returns
+    -------
+    Square matrix of shape ``N x N``, representing the solution to the
+    Lyapunov equation.
+
+    """
+
+    return _solve_continuous_lyapunov(A, Q)
+
+
 __all__ = [
     "cholesky",
     "solve",

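For context, a minimal editorial usage sketch of the two new user-facing functions (not part of this commit), mirroring the assertions in the tests below:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve_continuous_lyapunov, solve_discrete_lyapunov

A = pt.dmatrix("A")
Q = pt.dmatrix("Q")

# "direct" stays in pure PyTensor; "bilinear" wraps scipy.linalg.solve_discrete_lyapunov
f_disc = pytensor.function([A, Q], solve_discrete_lyapunov(A, Q, method="direct"))
f_cont = pytensor.function([A, Q], solve_continuous_lyapunov(A, Q))

rng = np.random.default_rng(0)
a = rng.normal(size=(5, 5)) * 0.5
q = rng.normal(size=(5, 5))

X = f_disc(a, q)
assert np.allclose(a @ X @ a.T - X + q, 0.0)  # A X A^H - X + Q = 0

Y = f_cont(a, q)
assert np.allclose(a @ Y + Y @ a.T, q)  # A X + X A^H = Q
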
tests/tensor/test_slinalg.py

Lines changed: 70 additions & 6 deletions
@@ -2,13 +2,12 @@
 import itertools
 
 import numpy as np
-import numpy.linalg
 import pytest
 import scipy
 
 import pytensor
 from pytensor import function, grad
-from pytensor import tensor as at
+from pytensor import tensor as pt
 from pytensor.configdefaults import config
 from pytensor.tensor.slinalg import (
     Cholesky,
@@ -23,6 +22,8 @@
     expm,
     kron,
     solve,
+    solve_continuous_lyapunov,
+    solve_discrete_lyapunov,
     solve_triangular,
 )
 from pytensor.tensor.type import dmatrix, matrix, tensor, vector
@@ -155,7 +156,7 @@ def test_eigvalsh():
     # We need to test None separately, as otherwise DebugMode will
     # complain, as this isn't a valid ndarray.
     b = None
-    B = at.NoneConst
+    B = pt.NoneConst
     f = function([A], eigvalsh(A, B))
     w = f(a)
     refw = scipy.linalg.eigvalsh(a, b)
@@ -215,7 +216,7 @@ def test_infer_shape(self, b_shape):
         rng = np.random.default_rng(utt.fetch_seed())
         A = matrix()
         b_val = np.asarray(rng.random(b_shape), dtype=config.floatX)
-        b = at.as_tensor_variable(b_val).type()
+        b = pt.as_tensor_variable(b_val).type()
         self._compile_and_check(
             [A, b],
             [solve(A, b)],
@@ -292,7 +293,7 @@ def test_infer_shape(self, b_shape):
         rng = np.random.default_rng(utt.fetch_seed())
         A = matrix()
         b_val = np.asarray(rng.random(b_shape), dtype=config.floatX)
-        b = at.as_tensor_variable(b_val).type()
+        b = pt.as_tensor_variable(b_val).type()
         self._compile_and_check(
             [A, b],
             [solve_triangular(A, b)],
@@ -514,7 +515,6 @@ def test_expm_grad_3():
 
 
 class TestKron(utt.InferShapeTester):
-
     rng = np.random.default_rng(43)
 
     def setup_method(self):
@@ -552,3 +552,67 @@ def test_numpy_2d(self):
         b = self.rng.random(shp1).astype(config.floatX)
         out = f(a, b)
         assert np.allclose(out, np.kron(a, b))
+
+
+def test_solve_discrete_lyapunov_via_direct_real():
+    N = 5
+    rng = np.random.default_rng(utt.fetch_seed())
+    a = pt.dmatrix()
+    q = pt.dmatrix()
+    f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")])
+
+    A = rng.normal(size=(N, N))
+    Q = rng.normal(size=(N, N))
+
+    X = f(A, Q)
+    assert np.allclose(A @ X @ A.T - X + Q, 0.0)
+
+    utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
+
+
+def test_solve_discrete_lyapunov_via_direct_complex():
+    N = 5
+    rng = np.random.default_rng(utt.fetch_seed())
+    a = pt.zmatrix()
+    q = pt.zmatrix()
+    f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")])
+
+    A = rng.normal(size=(N, N)) + rng.normal(size=(N, N)) * 1j
+    Q = rng.normal(size=(N, N))
+    X = f(A, Q)
+    assert np.allclose(A @ X @ A.conj().T - X + Q, 0.0)
+
+    # TODO: the .conj() method currently does not have a gradient; add this test when gradients are implemented.
+    # utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
+
+
+def test_solve_discrete_lyapunov_via_bilinear():
+    N = 5
+    rng = np.random.default_rng(utt.fetch_seed())
+    a = pt.dmatrix()
+    q = pt.dmatrix()
+    f = function([a, q], [solve_discrete_lyapunov(a, q, method="bilinear")])
+
+    A = rng.normal(size=(N, N))
+    Q = rng.normal(size=(N, N))
+
+    X = f(A, Q)
+    assert np.allclose(A @ X @ A.conj().T - X + Q, 0.0)
+
+    utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
+
+
+def test_solve_continuous_lyapunov():
+    N = 5
+    rng = np.random.default_rng(utt.fetch_seed())
+    a = pt.dmatrix()
+    q = pt.dmatrix()
+    f = function([a, q], [solve_continuous_lyapunov(a, q)])
+
+    A = rng.normal(size=(N, N))
+    Q = rng.normal(size=(N, N))
+    X = f(A, Q)
+
+    assert np.allclose(A @ X + X @ A.conj().T, Q)
+
+    utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng)
