@@ -537,15 +537,22 @@ def test_solve_discrete_lyapunov(
         precision = int(dtype[-2:])  # 64 or 32
         dtype = f"complex{int(2 * precision)}"

+    A1, A2 = rng.normal(size=(2, *shape)).astype(dtype)
+    Q1, Q2 = rng.normal(size=(2, *shape)).astype(dtype)
+
+    if use_complex:
+        A = A1 + 1j * A2
+        Q = Q1 + 1j * Q2
+    else:
+        A = A1
+        Q = Q1
+
     a = pt.tensor(name="a", shape=shape, dtype=dtype)
     q = pt.tensor(name="q", shape=shape, dtype=dtype)

     x = solve_discrete_lyapunov(a, q, method=method)
     f = function([a, q], x)

-    A = rng.normal(size=shape).astype(dtype)
-    Q = rng.normal(size=shape).astype(dtype)
-
     X = f(A, Q)
     Q_recovered = vec_recover_Q(A, X, continuous=False)

@@ -561,15 +568,12 @@ def test_solve_discrete_lyapunov_gradient(
 ):
     if config.floatX == "float32":
         pytest.skip(reason="Not enough precision in float32 to get a good gradient")
-
-    rng = np.random.default_rng(utt.fetch_seed())
-    dtype = config.floatX
     if use_complex:
-        precision = int(dtype[-2:])  # 64 or 32
-        dtype = f"complex{int(2 * precision)}"
+        pytest.skip(reason="Complex numbers are not supported in the gradient test")

-    A = rng.normal(size=shape).astype(dtype)
-    Q = rng.normal(size=shape).astype(dtype)
+    rng = np.random.default_rng(utt.fetch_seed())
+    A = rng.normal(size=shape).astype(config.floatX)
+    Q = rng.normal(size=shape).astype(config.floatX)

     utt.verify_grad(
         functools.partial(solve_discrete_lyapunov, method=method),
@@ -579,14 +583,29 @@ def test_solve_discrete_lyapunov_gradient(


 @pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
-def test_solve_continuous_lyapunov(shape: tuple[int]):
+@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
+def test_solve_continuous_lyapunov(shape: tuple[int], use_complex: bool):
     rng = np.random.default_rng(utt.fetch_seed())
-    a = pt.tensor(name="a", shape=shape)
-    q = pt.tensor(name="q", shape=shape)
+
+    dtype = config.floatX
+    if use_complex:
+        precision = int(dtype[-2:])  # 64 or 32
+        dtype = f"complex{int(2 * precision)}"
+
+    A1, A2 = rng.normal(size=(2, *shape)).astype(dtype)
+    Q1, Q2 = rng.normal(size=(2, *shape)).astype(dtype)
+
+    if use_complex:
+        A = A1 + 1j * A2
+        Q = Q1 + 1j * Q2
+    else:
+        A = A1
+        Q = Q1
+
+    a = pt.tensor(name="a", shape=shape, dtype=dtype)
+    q = pt.tensor(name="q", shape=shape, dtype=dtype)
     f = function([a, q], [solve_continuous_lyapunov(a, q)])

-    A = rng.normal(size=shape).astype(config.floatX)
-    Q = rng.normal(size=shape).astype(config.floatX)
     X = f(A, Q)

     Q_recovered = vec_recover_Q(A, X, continuous=True)
@@ -596,9 +615,12 @@ def test_solve_continuous_lyapunov(shape: tuple[int]):


 @pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
-def test_solve_continuous_lyapunov_grad(shape: tuple[int]):
+@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
+def test_solve_continuous_lyapunov_grad(shape: tuple[int], use_complex):
     if config.floatX == "float32":
         pytest.skip(reason="Not enough precision in float32 to get a good gradient")
+    if use_complex:
+        pytest.skip(reason="Complex numbers are not supported in the gradient test")

     rng = np.random.default_rng(utt.fetch_seed())
     A = rng.normal(size=shape).astype(config.floatX)