Skip to content

Commit 289b597

Browse files
Test numba solve gradients
1 parent acec22d commit 289b597

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

pytensor/link/numba/dispatch/_LAPACK.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def numba_xgetrf(cls, dtype):
279279
def numba_xgetrs(cls, dtype):
280280
"""
281281
Solve a system of linear equations A @ X = B or A.T @ X = B with a general N-by-N matrix A using the LU
282-
factorization computed by numba_getrf.
282+
factorization computed by GETRF.
283283
284284
Called by scipy.linalg.lu_solve
285285
"""
@@ -302,8 +302,8 @@ def numba_xgetrs(cls, dtype):
302302
@classmethod
303303
def numba_xsysv(cls, dtype):
304304
"""
305-
Solve a system of linear equations A @ X = B with a symmetric matrix A using the factorization computed by
306-
sytrf (LDL or UDU).
305+
Solve a system of linear equations A @ X = B with a symmetric matrix A using the diagonal pivoting method,
306+
factorizing A into LDL^T or UDU^T form, depending on the value of UPLO
307307
308308
Called by scipy.linalg.solve when assume_a == "sym"
309309
"""
@@ -327,10 +327,8 @@ def numba_xsysv(cls, dtype):
327327
@classmethod
328328
def numba_xsycon(cls, dtype):
329329
"""
330-
Estimates the reciprocal of the condition number of a symmetric matrix A using the factorization computed by
331-
sytrf (LDL or UDU).
332-
333-
Called by scipy.linalg.solve when assume_a == "sym"
330+
Estimate the reciprocal of the condition number of a symmetric matrix A using the UDU or LDL factorization
331+
computed by xSYTRF.
334332
"""
335333
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sycon")
336334

tests/link/numba/test_slinalg.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
from functools import partial
23

34
import numpy as np
45
import pytest
@@ -162,7 +163,7 @@ def test_numba_Cholesky_raise_on(on_error):
162163
@pytest.mark.parametrize("trans", [True, False], ids=["trans=True", "trans=False"])
163164
def test_numba_Cholesky_grad(lower, trans):
164165
rng = np.random.default_rng(utt.fetch_seed())
165-
L = rng.random.normal(size=(5, 5)).astype(floatX)
166+
L = rng.normal(size=(5, 5)).astype(floatX)
166167
X = L @ L.T
167168

168169
utt.verify_grad(pt.linalg.cholesky, [X])
@@ -358,9 +359,11 @@ def test_solve(b_func, b_size, assume_a, transposed):
358359

359360
X_np = f(A_val, b_val)
360361
op = f.maker.fgraph.outputs[0].owner.op
362+
361363
# overwrite_b is preferred when both inputs can be destroyed
362364
assert op.destroy_map == {0: [1]}
363365

366+
# Test that the result is numerically correct
364367
np.testing.assert_allclose(
365368
transpose_func(A_val_copy, transposed) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL
366369
)
@@ -370,6 +373,15 @@ def test_solve(b_func, b_size, assume_a, transposed):
370373
assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
371374

372375
# Test gradients
376+
solve = partial(
377+
pt.linalg.solve,
378+
lower=False,
379+
assume_a=assume_a,
380+
transposed=transposed,
381+
b_ndim=len(b_size),
382+
)
383+
384+
utt.verify_grad(solve, [A_val_copy, b_val_copy], mode="NUMBA")
373385

374386

375387
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)