@@ -525,7 +525,7 @@ def recover_Q(A, X, continuous=True):
525525vec_recover_Q = np .vectorize (recover_Q , signature = "(m,m),(m,m),()->(m,m)" )
526526
527527
528- @pytest .mark .parametrize ("use_complex" , [False , True ])
528+ @pytest .mark .parametrize ("use_complex" , [False , True ], ids = [ "float" , "complex" ] )
529529@pytest .mark .parametrize ("shape" , [(5 , 5 ), (5 , 5 , 5 )], ids = ["matrix" , "batch" ])
530530@pytest .mark .parametrize ("method" , ["direct" , "bilinear" ])
531531@pytest .mark .filterwarnings ("ignore::UserWarning" )
@@ -541,7 +541,8 @@ def test_solve_discrete_lyapunov(
541541 a = pt .tensor (name = "a" , shape = shape , dtype = dtype )
542542 q = pt .tensor (name = "q" , shape = shape , dtype = dtype )
543543
544- f = function ([a , q ], solve_discrete_lyapunov (a , q , method = method ))
544+ x = solve_discrete_lyapunov (a , q , method = method )
545+ f = function ([a , q ], x )
545546
546547 A = rng .normal (size = shape )
547548 Q = rng .normal (size = shape )
@@ -551,7 +552,9 @@ def test_solve_discrete_lyapunov(
551552 np .testing .assert_allclose (Q_recovered , Q )
552553
553554 utt .verify_grad (
554- functools .partial (solve_discrete_lyapunov , method = method ), pt = [A , Q ], rng = rng
555+ functools .partial (solve_discrete_lyapunov , method = method ),
556+ pt = [A , Q ],
557+ rng = rng ,
555558 )
556559
557560
@@ -588,7 +591,6 @@ def test_solve_discrete_are_forward(add_batch_dim):
588591
589592 x = solve_discrete_are (a , b , q , r )
590593
591- # A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q
592594 def eval_fun (a , b , q , r , x ):
593595 term_1 = a .T @ x @ a
594596 term_2 = a .T @ x @ b
0 commit comments