@@ -797,9 +797,7 @@ def make_node(self, A, B):
797797 def perform (self , node , inputs , output_storage ):
798798 (A , B ) = inputs
799799 X = output_storage [0 ]
800- out_dtype = node .outputs [0 ].type .dtype
801-
802- X [0 ] = scipy .linalg .solve_continuous_lyapunov (A , B ).astype (out_dtype )
800+ X [0 ] = scipy .linalg .solve_continuous_lyapunov (A , B )
803801
804802 def infer_shape (self , fgraph , node , shapes ):
805803 return [shapes [0 ]]
@@ -820,6 +818,30 @@ def grad(self, inputs, output_grads):
820818 return [A_bar , Q_bar ]
821819
822820
821+ _solve_continuous_lyapunov = Blockwise (SolveContinuousLyapunov ())
822+
823+
824+ def solve_continuous_lyapunov (A : TensorLike , Q : TensorLike ) -> TensorVariable :
825+ """
826+ Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
827+
828+ Parameters
829+ ----------
830+ A: TensorLike
831+ Square matrix of shape ``N x N``.
832+ Q: TensorLike
833+ Square matrix of shape ``N x N``.
834+
835+ Returns
836+ -------
837+ X: TensorVariable
838+ Square matrix of shape ``N x N``
839+
840+ """
841+
842+ return cast (TensorVariable , _solve_continuous_lyapunov (A , Q ))
843+
844+
823845class BilinearSolveDiscreteLyapunov (Op ):
824846 """
825847 Solves a discrete lyapunov equation, :math:`AXA^H - X = Q`, for :math:`X.
@@ -844,10 +866,7 @@ def perform(self, node, inputs, output_storage):
844866 (A , B ) = inputs
845867 X = output_storage [0 ]
846868
847- out_dtype = node .outputs [0 ].type .dtype
848- X [0 ] = scipy .linalg .solve_discrete_lyapunov (A , B , method = "bilinear" ).astype (
849- out_dtype
850- )
869+ X [0 ] = scipy .linalg .solve_discrete_lyapunov (A , B , method = "bilinear" )
851870
852871 def infer_shape (self , fgraph , node , shapes ):
853872 return [shapes [0 ]]
@@ -869,6 +888,9 @@ def grad(self, inputs, output_grads):
869888 return [A_bar , Q_bar ]
870889
871890
891+ _bilinear_solve_discrete_lyapunov = Blockwise (BilinearSolveDiscreteLyapunov ())
892+
893+
872894def _direct_solve_discrete_lyapunov (
873895 A : TensorVariable , Q : TensorVariable
874896) -> TensorVariable :
@@ -932,33 +954,12 @@ def solve_discrete_lyapunov(
932954 return cast (TensorVariable , X )
933955
934956 elif method == "bilinear" :
935- return cast (TensorVariable , Blockwise ( BilinearSolveDiscreteLyapunov ()) (A , Q ))
957+ return cast (TensorVariable , _bilinear_solve_discrete_lyapunov (A , Q ))
936958
937959 else :
938960 raise ValueError (f"Unknown method { method } " )
939961
940962
941- def solve_continuous_lyapunov (A : TensorLike , Q : TensorLike ) -> TensorVariable :
942- """
943- Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
944-
945- Parameters
946- ----------
947- A: TensorLike
948- Square matrix of shape ``N x N``.
949- Q: TensorLike
950- Square matrix of shape ``N x N``.
951-
952- Returns
953- -------
954- X: TensorVariable
955- Square matrix of shape ``N x N``
956-
957- """
958-
959- return cast (TensorVariable , Blockwise (SolveContinuousLyapunov ())(A , Q ))
960-
961-
962963class SolveDiscreteARE (Op ):
963964 __props__ = ("enforce_Q_symmetric" ,)
964965 gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"
@@ -987,9 +988,7 @@ def perform(self, node, inputs, output_storage):
987988 if self .enforce_Q_symmetric :
988989 Q = 0.5 * (Q + Q .T )
989990
990- X [0 ] = scipy .linalg .solve_discrete_are (A , B , Q , R ).astype (
991- node .outputs [0 ].type .dtype
992- )
991+ X [0 ] = scipy .linalg .solve_discrete_are (A , B , Q , R )
993992
994993 def infer_shape (self , fgraph , node , shapes ):
995994 return [shapes [0 ]]
0 commit comments