Skip to content

Commit 1a33bdc

Browse files
Test against complex inputs
1 parent 6c7f1ba commit 1a33bdc

File tree

1 file changed

+38
-16
lines changed

1 file changed

+38
-16
lines changed

tests/tensor/test_slinalg.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -537,15 +537,22 @@ def test_solve_discrete_lyapunov(
537537
precision = int(dtype[-2:]) # 64 or 32
538538
dtype = f"complex{int(2 * precision)}"
539539

540+
A1, A2 = rng.normal(size=(2, *shape)).astype(dtype)
541+
Q1, Q2 = rng.normal(size=(2, *shape)).astype(dtype)
542+
543+
if use_complex:
544+
A = A1 + 1j * A2
545+
Q = Q1 + 1j * Q2
546+
else:
547+
A = A1
548+
Q = Q1
549+
540550
a = pt.tensor(name="a", shape=shape, dtype=dtype)
541551
q = pt.tensor(name="q", shape=shape, dtype=dtype)
542552

543553
x = solve_discrete_lyapunov(a, q, method=method)
544554
f = function([a, q], x)
545555

546-
A = rng.normal(size=shape).astype(dtype)
547-
Q = rng.normal(size=shape).astype(dtype)
548-
549556
X = f(A, Q)
550557
Q_recovered = vec_recover_Q(A, X, continuous=False)
551558

@@ -561,15 +568,12 @@ def test_solve_discrete_lyapunov_gradient(
561568
):
562569
if config.floatX == "float32":
563570
pytest.skip(reason="Not enough precision in float32 to get a good gradient")
564-
565-
rng = np.random.default_rng(utt.fetch_seed())
566-
dtype = config.floatX
567571
if use_complex:
568-
precision = int(dtype[-2:]) # 64 or 32
569-
dtype = f"complex{int(2 * precision)}"
572+
pytest.skip(reason="Complex numbers are not supported in the gradient test")
570573

571-
A = rng.normal(size=shape).astype(dtype)
572-
Q = rng.normal(size=shape).astype(dtype)
574+
rng = np.random.default_rng(utt.fetch_seed())
575+
A = rng.normal(size=shape).astype(config.floatX)
576+
Q = rng.normal(size=shape).astype(config.floatX)
573577

574578
utt.verify_grad(
575579
functools.partial(solve_discrete_lyapunov, method=method),
@@ -579,14 +583,29 @@ def test_solve_discrete_lyapunov_gradient(
579583

580584

581585
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
582-
def test_solve_continuous_lyapunov(shape: tuple[int]):
586+
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
587+
def test_solve_continuous_lyapunov(shape: tuple[int], use_complex: bool):
583588
rng = np.random.default_rng(utt.fetch_seed())
584-
a = pt.tensor(name="a", shape=shape)
585-
q = pt.tensor(name="q", shape=shape)
589+
590+
dtype = config.floatX
591+
if use_complex:
592+
precision = int(dtype[-2:]) # 64 or 32
593+
dtype = f"complex{int(2 * precision)}"
594+
595+
A1, A2 = rng.normal(size=(2, *shape)).astype(dtype)
596+
Q1, Q2 = rng.normal(size=(2, *shape)).astype(dtype)
597+
598+
if use_complex:
599+
A = A1 + 1j * A2
600+
Q = Q1 + 1j * Q2
601+
else:
602+
A = A1
603+
Q = Q1
604+
605+
a = pt.tensor(name="a", shape=shape, dtype=dtype)
606+
q = pt.tensor(name="q", shape=shape, dtype=dtype)
586607
f = function([a, q], [solve_continuous_lyapunov(a, q)])
587608

588-
A = rng.normal(size=shape).astype(config.floatX)
589-
Q = rng.normal(size=shape).astype(config.floatX)
590609
X = f(A, Q)
591610

592611
Q_recovered = vec_recover_Q(A, X, continuous=True)
@@ -596,9 +615,12 @@ def test_solve_continuous_lyapunov(shape: tuple[int]):
596615

597616

598617
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
599-
def test_solve_continuous_lyapunov_grad(shape: tuple[int]):
618+
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
619+
def test_solve_continuous_lyapunov_grad(shape: tuple[int], use_complex):
600620
if config.floatX == "float32":
601621
pytest.skip(reason="Not enough precision in float32 to get a good gradient")
622+
if use_complex:
623+
pytest.skip(reason="Complex numbers are not supported in the gradient test")
602624

603625
rng = np.random.default_rng(utt.fetch_seed())
604626
A = rng.normal(size=shape).astype(config.floatX)

0 commit comments

Comments
 (0)