Skip to content

Commit cc92f2d

Browse files
Improve tests
1 parent bdb989c commit cc92f2d

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

tests/link/jax/test_slinalg.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,16 +199,20 @@ def test_jax_eigvalsh(lower):
199199

200200

201201
@pytest.mark.parametrize("method", ["direct", "bilinear"])
202-
def test_jax_solve_discrete_lyapunov(method: Literal["direct", "bilinear"]):
203-
A = matrix("A")
204-
B = matrix("B")
202+
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
203+
def test_jax_solve_discrete_lyapunov(
204+
method: Literal["direct", "bilinear"], shape: tuple[int]
205+
):
206+
A = pt.tensor(name="A", shape=shape)
207+
B = pt.tensor(name="B", shape=shape)
205208
out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method)
206209
out_fg = FunctionGraph([A, B], [out])
207210

208211
compare_jax_and_py(
209212
out_fg,
210213
[
211-
np.random.normal(size=(5, 5)).astype(config.floatX),
212-
np.random.normal(size=(5, 5)).astype(config.floatX),
214+
np.random.normal(size=shape).astype(config.floatX),
215+
np.random.normal(size=shape).astype(config.floatX),
213216
],
217+
jax_mode="JAX",
214218
)

tests/tensor/test_slinalg.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ def recover_Q(A, X, continuous=True):
525525
vec_recover_Q = np.vectorize(recover_Q, signature="(m,m),(m,m),()->(m,m)")
526526

527527

528-
@pytest.mark.parametrize("use_complex", [False, True])
528+
@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
529529
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
530530
@pytest.mark.parametrize("method", ["direct", "bilinear"])
531531
@pytest.mark.filterwarnings("ignore::UserWarning")
@@ -541,7 +541,8 @@ def test_solve_discrete_lyapunov(
541541
a = pt.tensor(name="a", shape=shape, dtype=dtype)
542542
q = pt.tensor(name="q", shape=shape, dtype=dtype)
543543

544-
f = function([a, q], solve_discrete_lyapunov(a, q, method=method))
544+
x = solve_discrete_lyapunov(a, q, method=method)
545+
f = function([a, q], x)
545546

546547
A = rng.normal(size=shape)
547548
Q = rng.normal(size=shape)
@@ -551,7 +552,9 @@ def test_solve_discrete_lyapunov(
551552
np.testing.assert_allclose(Q_recovered, Q)
552553

553554
utt.verify_grad(
554-
functools.partial(solve_discrete_lyapunov, method=method), pt=[A, Q], rng=rng
555+
functools.partial(solve_discrete_lyapunov, method=method),
556+
pt=[A, Q],
557+
rng=rng,
555558
)
556559

557560

@@ -588,7 +591,6 @@ def test_solve_discrete_are_forward(add_batch_dim):
588591

589592
x = solve_discrete_are(a, b, q, r)
590593

591-
# A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q
592594
def eval_fun(a, b, q, r, x):
593595
term_1 = a.T @ x @ a
594596
term_2 = a.T @ x @ b

0 commit comments

Comments
 (0)