@@ -853,26 +853,20 @@ def grad(self, inputs, output_grads):
853853
854854
855855_solve_continuous_lyapunov = Blockwise (SolveContinuousLyapunov ())
856- _solve_bilinear_direct_lyapunov = cast (
856+ _solve_bilinear_discrete_lyapunov = cast (
857857 typing .Callable , Blockwise (BilinearSolveDiscreteLyapunov ())
858858)
859859
860860
861- def _direct_solve_discrete_lyapunov (
862- A : TensorVariable , Q : TensorVariable
863- ) -> TensorVariable :
864- # By default kron acts on tensors, but we need a vectorized version over matrices for this function
865- vec_kron = pt .vectorize (kron , "(m,n),(o,p)->(q,r)" )
866-
861+ def _direct_solve_discrete_lyapunov (A , Q ) -> TensorVariable :
867862 if A .type .dtype .startswith ("complex" ):
868- AxA = vec_kron (A , A .conj ())
863+ AxA = kron (A , A .conj ())
869864 else :
870- AxA = vec_kron (A , A )
865+ AxA = kron (A , A )
871866
872867 eye = pt .eye (AxA .shape [- 1 ])
873- q_shape = pt .concatenate ([Q .shape [:- 2 ], [- 1 ]])
874868
875- vec_Q = Q .reshape ( q_shape )
869+ vec_Q = Q .ravel ( )
876870 vec_X = solve (eye - AxA , vec_Q , b_ndim = 1 )
877871
878872 return cast (TensorVariable , reshape (vec_X , A .shape ))
@@ -912,10 +906,11 @@ def solve_discrete_lyapunov(
912906 Q = as_tensor_variable (Q )
913907
914908 if method == "direct" :
915- return _direct_solve_discrete_lyapunov (A , Q )
909+ signature = BilinearSolveDiscreteLyapunov .gufunc_signature
910+ return pt .vectorize (_direct_solve_discrete_lyapunov , signature = signature )(A , Q )
916911
917912 elif method == "bilinear" :
918- return cast (TensorVariable , _solve_bilinear_direct_lyapunov (A , Q ))
913+ return cast (TensorVariable , _solve_bilinear_discrete_lyapunov (A , Q ))
919914
920915 else :
921916 raise ValueError (f"Unknown method { method } " )
0 commit comments