Skip to content

Commit 3aaf869

Browse files
start gradient tests
1 parent a9c5657 commit 3aaf869

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

tests/link/numba/test_slinalg.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytensor
88
import pytensor.tensor as pt
99
from pytensor.graph import FunctionGraph
10+
from tests import unittest_tools as utt
1011
from tests.link.numba.test_basic import compare_numba_and_py
1112

1213

@@ -157,6 +158,16 @@ def test_numba_Cholesky_raise_on(on_error):
157158
assert np.all(np.isnan(f(test_value)))
158159

159160

161+
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
162+
@pytest.mark.parametrize("trans", [True, False], ids=["trans=True", "trans=False"])
163+
def test_numba_Cholesky_grad(lower, trans):
164+
rng = np.random.default_rng(utt.fetch_seed())
165+
L = rng.random.normal(size=(5, 5)).astype(floatX)
166+
X = L @ L.T
167+
168+
utt.verify_grad(pt.linalg.cholesky, [X])
169+
170+
160171
def test_block_diag():
161172
A = pt.matrix("A")
162173
B = pt.matrix("B")
@@ -358,6 +369,8 @@ def test_solve(b_func, b_size, assume_a, transposed):
358369
assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
359370
assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
360371

372+
# Test gradients
373+
361374

362375
@pytest.mark.parametrize(
363376
"b_func, b_size",

0 commit comments

Comments
 (0)