@@ -797,7 +797,9 @@ def make_node(self, A, B):
797797 def perform (self , node , inputs , output_storage ):
798798 (A , B ) = inputs
799799 X = output_storage [0 ]
800- X [0 ] = scipy .linalg .solve_continuous_lyapunov (A , B )
800+
801+ out_dtype = node .outputs [0 ].type .dtype
802+ X [0 ] = scipy .linalg .solve_continuous_lyapunov (A , B ).astype (out_dtype )
801803
802804 def infer_shape (self , fgraph , node , shapes ):
803805 return [shapes [0 ]]
@@ -866,7 +868,10 @@ def perform(self, node, inputs, output_storage):
866868 (A , B ) = inputs
867869 X = output_storage [0 ]
868870
869- X [0 ] = scipy .linalg .solve_discrete_lyapunov (A , B , method = "bilinear" )
871+ out_dtype = node .outputs [0 ].type .dtype
872+ X [0 ] = scipy .linalg .solve_discrete_lyapunov (A , B , method = "bilinear" ).astype (
873+ out_dtype
874+ )
870875
871876 def infer_shape (self , fgraph , node , shapes ):
872877 return [shapes [0 ]]
@@ -964,11 +969,8 @@ class SolveDiscreteARE(Op):
964969 __props__ = ("enforce_Q_symmetric" ,)
965970 gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"
966971
967- def __init__ (
968- self , enforce_Q_symmetric : bool = False , use_bilinear_lyapunov : bool = True
969- ):
972+ def __init__ (self , enforce_Q_symmetric : bool = False ):
970973 self .enforce_Q_symmetric = enforce_Q_symmetric
971- self .use_bilinear_lyapunov = use_bilinear_lyapunov
972974
973975 def make_node (self , A , B , Q , R ):
974976 A = as_tensor_variable (A )
@@ -988,7 +990,8 @@ def perform(self, node, inputs, output_storage):
988990 if self .enforce_Q_symmetric :
989991 Q = 0.5 * (Q + Q .T )
990992
991- X [0 ] = scipy .linalg .solve_discrete_are (A , B , Q , R )
993+ out_dtype = node .outputs [0 ].type .dtype
994+ X [0 ] = scipy .linalg .solve_discrete_are (A , B , Q , R ).astype (out_dtype )
992995
993996 def infer_shape (self , fgraph , node , shapes ):
994997 return [shapes [0 ]]
@@ -1000,16 +1003,16 @@ def grad(self, inputs, output_grads):
10001003 (dX ,) = output_grads
10011004 X = self (A , B , Q , R )
10021005
1003- K_inner = R + pt . linalg . matrix_dot (B .T , X , B )
1006+ K_inner = R + matrix_dot (B .T , X , B )
10041007
10051008 # K_inner is guaranteed to be symmetric, because X and R are symmetric
1006- K_inner_inv_BT = pt . linalg . solve (K_inner , B .T , assume_a = "sym" )
1009+ K_inner_inv_BT = solve (K_inner , B .T , assume_a = "sym" )
10071010 K = matrix_dot (K_inner_inv_BT , X , A )
10081011
10091012 A_tilde = A - B .dot (K )
10101013
10111014 dX_symm = 0.5 * (dX + dX .T )
1012- S = solve_discrete_lyapunov (A_tilde , dX_symm ). astype ( dX . type . dtype )
1015+ S = solve_discrete_lyapunov (A_tilde , dX_symm )
10131016
10141017 A_bar = 2 * matrix_dot (X , A_tilde , S )
10151018 B_bar = - 2 * matrix_dot (X , A_tilde , S , K .T )
0 commit comments