@@ -543,12 +543,33 @@ def test_solve_discrete_lyapunov(
543543 x = solve_discrete_lyapunov (a , q , method = method )
544544 f = function ([a , q ], x )
545545
546- A = rng .normal (size = shape )
547- Q = rng .normal (size = shape )
546+ A = rng .normal (size = shape ). astype ( dtype )
547+ Q = rng .normal (size = shape ). astype ( dtype )
548548
549549 X = f (A , Q )
550550 Q_recovered = vec_recover_Q (A , X , continuous = False )
551- np .testing .assert_allclose (Q_recovered , Q )
551+
552+ atol = rtol = 1e-4 if config .floatX == "float32" else 1e-8
553+ np .testing .assert_allclose (Q_recovered , Q , atol = atol , rtol = rtol )
554+
555+
556+ @pytest .mark .parametrize ("use_complex" , [False , True ], ids = ["float" , "complex" ])
557+ @pytest .mark .parametrize ("shape" , [(5 , 5 ), (5 , 5 , 5 )], ids = ["matrix" , "batch" ])
558+ @pytest .mark .parametrize ("method" , ["direct" , "bilinear" ])
559+ def test_solve_discrete_lyapunov_gradient (
560+ use_complex , shape : tuple [int ], method : Literal ["direct" , "bilinear" ]
561+ ):
562+ if config .floatX == "float32" :
563+ pytest .skip (reason = "Not enough precision in float32 to get a good gradient" )
564+
565+ rng = np .random .default_rng (utt .fetch_seed ())
566+ dtype = config .floatX
567+ if use_complex :
568+ precision = int (dtype [- 2 :]) # 64 or 32
569+ dtype = f"complex{ int (2 * precision )} "
570+
571+ A = rng .normal (size = shape ).astype (dtype )
572+ Q = rng .normal (size = shape ).astype (dtype )
552573
553574 utt .verify_grad (
554575 functools .partial (solve_discrete_lyapunov , method = method ),
@@ -564,13 +585,25 @@ def test_solve_continuous_lyapunov(shape: tuple[int]):
564585 q = pt .tensor (name = "q" , shape = shape )
565586 f = function ([a , q ], [solve_continuous_lyapunov (a , q )])
566587
567- A = rng .normal (size = shape )
568- Q = rng .normal (size = shape )
588+ A = rng .normal (size = shape ). astype ( config . floatX )
589+ Q = rng .normal (size = shape ). astype ( config . floatX )
569590 X = f (A , Q )
570591
571592 Q_recovered = vec_recover_Q (A , X , continuous = True )
572593
573- np .testing .assert_allclose (Q_recovered .squeeze (), Q )
594+ atol = rtol = 1e-2 if config .floatX == "float32" else 1e-8
595+ np .testing .assert_allclose (Q_recovered .squeeze (), Q , atol = atol , rtol = rtol )
596+
597+
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
def test_solve_continuous_lyapunov_grad(shape: tuple[int]):
    """Verify gradients of ``solve_continuous_lyapunov`` by finite differences.

    Runs for both a single matrix and a batch of matrices. Skipped under
    float32, where the numerical gradient check is too imprecise to pass.
    """
    if config.floatX == "float32":
        pytest.skip(reason="Not enough precision in float32 to get a good gradient")

    gen = np.random.default_rng(utt.fetch_seed())
    # Draw A first, then Q, so the stream of random values matches the
    # seeded generator's expected order.
    A, Q = (gen.normal(size=shape).astype(config.floatX) for _ in range(2))

    utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=gen)
575608
576609
0 commit comments