Skip to content

Commit a43886e

Browse files
Appease ViPy (Vieira-py type checking)
1 parent 1a33bdc commit a43886e

File tree

2 files changed

+52
-18
lines changed

2 files changed

+52
-18
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
972972
return [eye_input * (non_eye_input**0.5)]
973973

974974

975-
@node_rewriter([_solve_bilinear_discrete_lyapunov])
975+
@node_rewriter([_solve_bilinear_discrete_lyapunov]) # type: ignore
976976
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
977977
"""
978978
Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX

pytensor/tensor/slinalg.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,14 @@ def perform(self, node, inputs, outputs):
777777

778778

779779
class SolveContinuousLyapunov(Op):
780+
"""
781+
Solves a continuous Lyapunov equation, :math:`AX + XA^H = B`, for :math:`X.
782+
783+
Continuous time Lyapunov equations are special cases of Sylvester equations, :math:`AX + XB = C`, and can be solved
784+
efficiently using the Bartels-Stewart algorithm. For more details, see the docstring for
785+
scipy.linalg.solve_continuous_lyapunov
786+
"""
787+
780788
__props__ = ()
781789
gufunc_signature = "(m,m),(m,m)->(m,m)"
782790

@@ -815,6 +823,14 @@ def grad(self, inputs, output_grads):
815823

816824

817825
class BilinearSolveDiscreteLyapunov(Op):
826+
"""
827+
Solves a discrete lyapunov equation, :math:`AXA^H - X = Q`, for :math:`X.
828+
829+
The solution is computed by first transforming the discrete-time problem into a continuous-time form. The continuous
830+
time lyapunov is a special case of a Sylvester equation, and can be efficiently solved. For more details, see the
831+
docstring for scipy.linalg.solve_discrete_lyapunov
832+
"""
833+
818834
gufunc_signature = "(m,m),(m,m)->(m,m)"
819835

820836
def make_node(self, A, B):
@@ -861,7 +877,17 @@ def grad(self, inputs, output_grads):
861877
)
862878

863879

864-
def _direct_solve_discrete_lyapunov(A, Q) -> TensorVariable:
880+
def _direct_solve_discrete_lyapunov(
881+
A: TensorVariable, Q: TensorVariable
882+
) -> TensorVariable:
883+
r"""
884+
Directly solve the discrete Lyapunov equation :math:`A X A^H - X = Q` using the kronecker method of Magnus and
885+
Neudecker.
886+
887+
This involves constructing and inverting an intermediate matrix :math:`A \otimes A`, with shape :math:`N^2 x N^2`.
888+
As a result, this method scales poorly with the size of :math:`N`, and should be avoided for large :math:`N`.
889+
"""
890+
865891
if A.type.dtype.startswith("complex"):
866892
AxA = kron(A, A.conj())
867893
else:
@@ -876,17 +902,17 @@ def _direct_solve_discrete_lyapunov(A, Q) -> TensorVariable:
876902

877903

878904
def solve_discrete_lyapunov(
879-
A: TensorVariable,
880-
Q: TensorVariable,
905+
A: TensorLike,
906+
Q: TensorLike,
881907
method: Literal["direct", "bilinear"] = "bilinear",
882908
) -> TensorVariable:
883909
"""Solve the discrete Lyapunov equation :math:`A X A^H - X = Q`.
884910
885911
Parameters
886912
----------
887-
A: TensorVariable
913+
A: TensorLike
888914
Square matrix of shape N x N
889-
Q: TensorVariable
915+
Q: TensorLike
890916
Square matrix of shape N x N
891917
method: str, one of ``"direct"`` or ``"bilinear"``
892918
Solver method used, . ``"direct"`` solves the problem directly via matrix inversion. This has a pure
@@ -910,7 +936,8 @@ def solve_discrete_lyapunov(
910936

911937
if method == "direct":
912938
signature = BilinearSolveDiscreteLyapunov.gufunc_signature
913-
return pt.vectorize(_direct_solve_discrete_lyapunov, signature=signature)(A, Q)
939+
X = pt.vectorize(_direct_solve_discrete_lyapunov, signature=signature)(A, Q)
940+
return cast(TensorVariable, X)
914941

915942
elif method == "bilinear":
916943
return cast(TensorVariable, _solve_bilinear_discrete_lyapunov(A, Q))
@@ -919,15 +946,15 @@ def solve_discrete_lyapunov(
919946
raise ValueError(f"Unknown method {method}")
920947

921948

922-
def solve_continuous_lyapunov(A: TensorVariable, Q: TensorVariable) -> TensorVariable:
949+
def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
923950
"""
924951
Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
925952
926953
Parameters
927954
----------
928-
A: TensorVariable
955+
A: TensorLike
929956
Square matrix of shape ``N x N``.
930-
Q: TensorVariable
957+
Q: TensorLike
931958
Square matrix of shape ``N x N``.
932959
933960
Returns
@@ -1002,24 +1029,31 @@ def grad(self, inputs, output_grads):
10021029

10031030

10041031
def solve_discrete_are(
1005-
A: TensorVariable,
1006-
B: TensorVariable,
1007-
Q: TensorVariable,
1008-
R: TensorVariable,
1032+
A: TensorLike,
1033+
B: TensorLike,
1034+
Q: TensorLike,
1035+
R: TensorLike,
10091036
enforce_Q_symmetric: bool = False,
10101037
) -> TensorVariable:
10111038
"""
10121039
Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.
10131040
1041+
Discrete-time Algebraic Riccati equations arise in the context of optimal control and filtering problems, as the
1042+
solution to Linear-Quadratic Regulators (LQR), Linear-Quadratic-Guassian (LQG) control problems, and as the
1043+
steady-state covariance of the Kalman Filter.
1044+
1045+
Such problems typically have many solutions, but we are generally only interested in the unique *stabilizing*
1046+
solution. This stable solution, if it exists, will be returned by this function.
1047+
10141048
Parameters
10151049
----------
1016-
A: TensorVariable
1050+
A: TensorLike
10171051
Square matrix of shape M x M
1018-
B: TensorVariable
1052+
B: TensorLike
10191053
Square matrix of shape M x M
1020-
Q: TensorVariable
1054+
Q: TensorLike
10211055
Symmetric square matrix of shape M x M
1022-
R: TensorVariable
1056+
R: TensorLike
10231057
Square matrix of shape N x N
10241058
enforce_Q_symmetric: bool
10251059
If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry

0 commit comments

Comments
 (0)