1
1
import logging
2
+ import typing
2
3
import warnings
3
4
from typing import TYPE_CHECKING , Literal , Union
4
5
12
13
from pytensor .tensor import as_tensor_variable
13
14
from pytensor .tensor import basic as at
14
15
from pytensor .tensor import math as atm
16
+ from pytensor .tensor .nlinalg import matrix_dot
15
17
from pytensor .tensor .shape import reshape
16
18
from pytensor .tensor .type import matrix , tensor , vector
17
19
from pytensor .tensor .var import TensorVariable
@@ -321,9 +323,6 @@ def L_op(self, inputs, outputs, output_gradients):
321
323
return res
322
324
323
325
324
- solvetriangular = SolveTriangular ()
325
-
326
-
327
326
def solve_triangular (
328
327
a : TensorVariable ,
329
328
b : TensorVariable ,
@@ -397,9 +396,6 @@ def perform(self, node, inputs, outputs):
397
396
)
398
397
399
398
400
- solve = Solve ()
401
-
402
-
403
399
def solve (a , b , assume_a = "gen" , lower = False , check_finite = True ):
404
400
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
405
401
@@ -748,13 +744,9 @@ def grad(self, inputs, output_grads):
748
744
749
745
750
746
_solve_continuous_lyapunov = SolveContinuousLyapunov ()
751
- _solve_bilinear_direct_lyapunov = BilinearSolveDiscreteLyapunov ()
752
-
753
-
754
- def iscomplexobj (x ):
755
- type_ = x .type
756
- dtype = type_ .dtype
757
- return "complex" in dtype
747
+ _solve_bilinear_direct_lyapunov = typing .cast (
748
+ typing .Callable , BilinearSolveDiscreteLyapunov ()
749
+ )
758
750
759
751
760
752
def _direct_solve_discrete_lyapunov (A : "TensorLike" , Q : "TensorLike" ) -> TensorVariable :
@@ -767,7 +759,7 @@ def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorV
767
759
AA = kron (A_ , A_ )
768
760
769
761
X = solve (pt .eye (AA .shape [0 ]) - AA , Q_ .ravel ())
770
- return reshape (X , Q_ .shape )
762
+ return typing . cast ( TensorVariable , reshape (X , Q_ .shape ) )
771
763
772
764
773
765
def solve_discrete_lyapunov (
@@ -803,7 +795,7 @@ def solve_discrete_lyapunov(
803
795
if method == "direct" :
804
796
return _direct_solve_discrete_lyapunov (A , Q )
805
797
if method == "bilinear" :
806
- return _solve_bilinear_direct_lyapunov (A , Q )
798
+ return typing . cast ( TensorVariable , _solve_bilinear_direct_lyapunov (A , Q ) )
807
799
808
800
809
801
def solve_continuous_lyapunov (A : "TensorLike" , Q : "TensorLike" ) -> TensorVariable :
@@ -823,7 +815,90 @@ def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariabl
823
815
824
816
"""
825
817
826
- return _solve_continuous_lyapunov (A , Q )
818
+ return typing .cast (TensorVariable , _solve_continuous_lyapunov (A , Q ))
819
+
820
+
821
+ class SolveDiscreteARE (pt .Op ):
822
+ __props__ = ("enforce_Q_symmetric" ,)
823
+
824
+ def __init__ (self , enforce_Q_symmetric = False ):
825
+ self .enforce_Q_symmetric = enforce_Q_symmetric
826
+
827
+ def make_node (self , A , B , Q , R ):
828
+ A = as_tensor_variable (A )
829
+ B = as_tensor_variable (B )
830
+ Q = as_tensor_variable (Q )
831
+ R = as_tensor_variable (R )
832
+
833
+ out_dtype = pytensor .scalar .upcast (A .dtype , B .dtype , Q .dtype , R .dtype )
834
+ X = pytensor .tensor .matrix (dtype = out_dtype )
835
+
836
+ return pytensor .graph .basic .Apply (self , [A , B , Q , R ], [X ])
837
+
838
+ def perform (self , node , inputs , output_storage ):
839
+ A , B , Q , R = inputs
840
+ X = output_storage [0 ]
841
+
842
+ if self .enforce_Q_symmetric :
843
+ Q = 0.5 * (Q + Q .T )
844
+
845
+ X [0 ] = scipy .linalg .solve_discrete_are (A , B , Q , R ).astype (
846
+ node .outputs [0 ].type .dtype
847
+ )
848
+
849
+ def infer_shape (self , fgraph , node , shapes ):
850
+ return [shapes [0 ]]
851
+
852
+ def grad (self , inputs , output_grads ):
853
+ # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
854
+ A , B , Q , R = inputs
855
+
856
+ (dX ,) = output_grads
857
+ X = self (A , B , Q , R )
858
+
859
+ K_inner = R + pt .linalg .matrix_dot (B .T , X , B )
860
+ K_inner_inv = pt .linalg .solve (K_inner , pt .eye (R .shape [0 ]))
861
+ K = matrix_dot (K_inner_inv , B .T , X , A )
862
+
863
+ A_tilde = A - B .dot (K )
864
+
865
+ dX_symm = 0.5 * (dX + dX .T )
866
+ S = solve_discrete_lyapunov (A_tilde , dX_symm ).astype (dX .type .dtype )
867
+
868
+ A_bar = 2 * matrix_dot (X , A_tilde , S )
869
+ B_bar = - 2 * matrix_dot (X , A_tilde , S , K .T )
870
+ Q_bar = S
871
+ R_bar = matrix_dot (K , S , K .T )
872
+
873
+ return [A_bar , B_bar , Q_bar , R_bar ]
874
+
875
+
876
+ def solve_discrete_are (A , B , Q , R , enforce_Q_symmetric = False ) -> TensorVariable :
877
+ """
878
+ Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.
879
+
880
+ Parameters
881
+ ----------
882
+ A: ArrayLike
883
+ Square matrix of shape M x M
884
+ B: ArrayLike
885
+ Square matrix of shape M x M
886
+ Q: ArrayLike
887
+ Symmetric square matrix of shape M x M
888
+ R: ArrayLike
889
+ Square matrix of shape N x N
890
+ enforce_Q_symmetric: bool
891
+ If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry
892
+
893
+ Returns
894
+ -------
895
+ X: pt.matrix
896
+ Square matrix of shape M x M, representing the solution to the DARE
897
+ """
898
+
899
+ return typing .cast (
900
+ TensorVariable , SolveDiscreteARE (enforce_Q_symmetric )(A , B , Q , R )
901
+ )
827
902
828
903
829
904
__all__ = [
@@ -832,4 +907,8 @@ def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariabl
832
907
"eigvalsh" ,
833
908
"kron" ,
834
909
"expm" ,
910
+ "solve_discrete_lyapunov" ,
911
+ "solve_continuous_lyapunov" ,
912
+ "solve_discrete_are" ,
913
+ "solve_triangular" ,
835
914
]
0 commit comments