@@ -778,6 +778,7 @@ def perform(self, node, inputs, outputs):
778778
779779class SolveContinuousLyapunov (Op ):
780780 __props__ = ()
781+ gufunc_signature = "(m,m),(m,m)->(m,m)"
781782
782783 def make_node (self , A , B ):
783784 A = as_tensor_variable (A )
@@ -814,6 +815,8 @@ def grad(self, inputs, output_grads):
814815
815816
816817class BilinearSolveDiscreteLyapunov (Op ):
818+ gufunc_signature = "(m,m),(m,m)->(m,m)"
819+
817820 def make_node (self , A , B ):
818821 A = as_tensor_variable (A )
819822 B = as_tensor_variable (B )
@@ -849,84 +852,102 @@ def grad(self, inputs, output_grads):
849852 return [A_bar , Q_bar ]
850853
851854
# Shared module-level solver instances. Each core Op declares a
# ``gufunc_signature`` and is wrapped in ``Blockwise`` before being exposed.
_solve_continuous_lyapunov = Blockwise(SolveContinuousLyapunov())

# Exposed as a plain callable for the type checker (presumably because the
# wrapped Op has no precise call signature -- confirm against Blockwise).
_solve_bilinear_direct_lyapunov = cast(
    typing.Callable,
    Blockwise(BilinearSolveDiscreteLyapunov()),
)
854859
855860
def _direct_solve_discrete_lyapunov(
    A: TensorVariable, Q: TensorVariable
) -> TensorVariable:
    """Solve the discrete Lyapunov equation as one vectorised linear system.

    The last two dimensions of ``Q`` are flattened into a vector and the
    system ``(I - A kron conj(A)) vec(X) = vec(Q)`` is solved for ``vec(X)``
    (``conj`` is only applied when ``A`` has a complex dtype). This is the
    flattened form of ``X - A X A^H = Q``.

    Parameters
    ----------
    A: TensorVariable
        Tensor whose last two dimensions hold the square coefficient matrix.
    Q: TensorVariable
        Tensor whose last two dimensions hold the right-hand-side matrix.

    Returns
    -------
    X: TensorVariable
        Solution with the batch dimensions produced by the linear solve and
        the trailing matrix dimensions of ``A``.
    """
    # By default kron acts on whole tensors; we need a version that is
    # vectorised over the trailing matrix dimensions.
    vec_kron = pt.vectorize(kron, "(m,n),(o,p)->(q,r)")

    is_complex = A.type.dtype.startswith("complex")
    AxA = vec_kron(A, A.conj() if is_complex else A)

    eye = pt.eye(AxA.shape[-1])

    # Flatten only the two core dimensions of Q; batch dimensions are kept.
    vec_Q = Q.reshape(pt.concatenate([Q.shape[:-2], [-1]]))
    vec_X = solve(eye - AxA, vec_Q, b_ndim=1)

    # Rebuild the matrix dimensions on top of whatever batch dimensions the
    # solve produced. Reshaping to ``A.shape`` alone is wrong when Q carries
    # batch dimensions that A does not have.
    out_shape = pt.concatenate([vec_X.shape[:-1], A.shape[-2:]])
    return cast(TensorVariable, reshape(vec_X, out_shape))
867879
868880
def solve_discrete_lyapunov(
    A: TensorVariable,
    Q: TensorVariable,
    method: Literal["direct", "bilinear"] = "direct",
) -> TensorVariable:
    """Solve the discrete Lyapunov equation :math:`A X A^H - X + Q = 0`.

    Parameters
    ----------
    A: TensorVariable
        Square matrix of shape N x N
    Q: TensorVariable
        Square matrix of shape N x N
    method: str, one of ``"direct"`` or ``"bilinear"``
        Solver method used. ``"direct"`` solves the problem directly as a single linear system built from a
        Kronecker product. This has a pure PyTensor implementation and can thus be cross-compiled to supported
        backends, and should be preferred when ``N`` is not large. The direct method scales poorly with the size
        of ``N``, and the bilinear can be used in these cases.

    Returns
    -------
    X: TensorVariable
        Square matrix of shape ``N x N``. Solution to the Lyapunov equation

    Raises
    ------
    ValueError
        If ``method`` is neither ``"direct"`` nor ``"bilinear"``.

    """
    if method not in ["direct", "bilinear"]:
        raise ValueError(
            f'Parameter "method" must be one of "direct" or "bilinear", found {method}'
        )

    # Inputs are converted here, so array-likes are accepted as well as
    # symbolic tensors.
    A = as_tensor_variable(A)
    Q = as_tensor_variable(Q)

    if method == "direct":
        return _direct_solve_discrete_lyapunov(A, Q)

    # Validation above leaves "bilinear" as the only remaining option, so
    # return explicitly instead of falling off the end of the function.
    return cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))
903919
904920
def solve_continuous_lyapunov(A: TensorVariable, Q: TensorVariable) -> TensorVariable:
    """
    Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.

    Parameters
    ----------
    A: TensorVariable
        Square matrix of shape ``N x N``.
    Q: TensorVariable
        Square matrix of shape ``N x N``.

    Returns
    -------
    X: TensorVariable
        Square matrix of shape ``N x N``

    """
    # NOTE(review): the sign convention in the docstring is not checkable from
    # this function alone -- confirm against SolveContinuousLyapunov.perform.
    solution = _solve_continuous_lyapunov(A, Q)
    return cast(TensorVariable, solution)
923940
924941
925942class SolveDiscreteARE (pt .Op ):
926- __props__ = ("enforce_Q_symmetric" ,)
943+ __props__ = ("enforce_Q_symmetric" , "use_bilinear_lyapunov" )
944+ gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"
927945
928- def __init__ (self , enforce_Q_symmetric = False ):
946+ def __init__ (
947+ self , enforce_Q_symmetric : bool = False , use_bilinear_lyapunov : bool = True
948+ ):
929949 self .enforce_Q_symmetric = enforce_Q_symmetric
950+ self .use_bilinear_lyapunov = use_bilinear_lyapunov
930951
931952 def make_node (self , A , B , Q , R ):
932953 A = as_tensor_variable (A )
@@ -961,13 +982,20 @@ def grad(self, inputs, output_grads):
961982 X = self (A , B , Q , R )
962983
963984 K_inner = R + pt .linalg .matrix_dot (B .T , X , B )
964- K_inner_inv = pt .linalg .solve (K_inner , pt .eye (R .shape [0 ]))
965- K = matrix_dot (K_inner_inv , B .T , X , A )
985+
986+ # K_inner is guaranteed to be symmetric, because X and R are symmetric
987+ K_inner_inv_BT = pt .linalg .solve (K_inner , B .T , assume_a = "sym" )
988+ K = matrix_dot (K_inner_inv_BT , X , A )
966989
967990 A_tilde = A - B .dot (K )
968991
969992 dX_symm = 0.5 * (dX + dX .T )
970- S = solve_discrete_lyapunov (A_tilde , dX_symm ).astype (dX .type .dtype )
993+ method : Literal ["bilinear" , "direct" ] = (
994+ "bilinear" if self .use_bilinear_lyapunov else "direct"
995+ )
996+ S = solve_discrete_lyapunov (A_tilde , dX_symm , method = method ).astype (
997+ dX .type .dtype
998+ )
971999
9721000 A_bar = 2 * matrix_dot (X , A_tilde , S )
9731001 B_bar = - 2 * matrix_dot (X , A_tilde , S , K .T )
@@ -977,30 +1005,43 @@ def grad(self, inputs, output_grads):
9771005 return [A_bar , B_bar , Q_bar , R_bar ]
9781006
9791007
def solve_discrete_are(
    A: TensorVariable,
    B: TensorVariable,
    Q: TensorVariable,
    R: TensorVariable,
    enforce_Q_symmetric: bool = False,
    use_bilinear_lyapunov: bool = True,
) -> TensorVariable:
    """
    Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.

    Parameters
    ----------
    A: TensorVariable
        Square matrix of shape M x M
    B: TensorVariable
        Matrix of shape M x N
    Q: TensorVariable
        Symmetric square matrix of shape M x M
    R: TensorVariable
        Square matrix of shape N x N
    enforce_Q_symmetric: bool
        If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry
    use_bilinear_lyapunov: bool
        If True, the bilinear method is used to solve a discrete Lyapunov equation when computing the gradients of
        the ARE. If False, the direct method is used instead. See the docstring for ``solve_discrete_lyapunov`` for
        details.

    Returns
    -------
    X: TensorVariable
        Square matrix of shape M x M, representing the solution to the DARE
    """
    # Both options must reach the Op: ``use_bilinear_lyapunov`` is read in
    # its ``grad`` and was previously dropped here, so the default always won.
    riccati_op = SolveDiscreteARE(
        enforce_Q_symmetric=enforce_Q_symmetric,
        use_bilinear_lyapunov=use_bilinear_lyapunov,
    )
    return cast(TensorVariable, Blockwise(riccati_op)(A, B, Q, R))
10041045
10051046
10061047def _largest_common_dtype (tensors : typing .Sequence [TensorVariable ]) -> np .dtype :
0 commit comments