
Commit 8a7f75d

Add numba dispatch for lu_factor
1 parent 963ae07
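In user-facing terms, this commit lets `pt.linalg.lu_factor` compile under PyTensor's NUMBA backend. A minimal sketch of the intended usage (illustrative only, not part of the commit; it assumes a working numba install and that `lu_factor` returns the same `(LU, pivots)` pair as scipy, mirroring the new test below):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    # Build a graph that LU-factorizes a matrix, then compile it with the
    # NUMBA backend that this commit teaches about lu_factor.
    A = pt.matrix("A")
    lu, piv = pt.linalg.lu_factor(A)  # assumed (LU, pivots) pair, as in scipy
    f = pytensor.function([A], [lu, piv], mode="NUMBA")

    A_val = np.random.default_rng(0).normal(size=(4, 4))
    lu_val, piv_val = f(A_val)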

2 files changed: 48 additions & 3 deletions

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 25 additions & 1 deletion
@@ -478,7 +478,8 @@ def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
 
 def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
     """
-    Placeholder for LU factorization; used by linalg.solve.
+    Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factor, this function also
+    returns an info code with diagnostic information.
     """
     getrf = scipy.linalg.get_lapack_funcs("getrf", (A,))
     A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)
@@ -519,6 +520,29 @@ def impl(
     return impl
 
 
+def _lu_factor(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray]:
+    """
+    Thin wrapper around scipy.linalg.lu_factor. Used as an overload target to avoid side effects on users who import
+    PyTensor.
+    """
+    return linalg.lu_factor(A, overwrite_a=overwrite_a)
+
+
+@overload(_lu_factor)
+def lu_factor_impl(
+    A: np.ndarray, overwrite_a: bool = False
+) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
+    ensure_lapack()
+    _check_scipy_linalg_matrix(A, "lu_factor")
+
+    def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
+        A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
+        _solve_check(int_ptr_to_val(INFO), 0)
+        return A_copy, IPIV
+
+    return impl
+
+
 def _lu_1(
     a: np.ndarray,
     permute_l: bool,

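The dispatch follows numba's standard `@overload` pattern: `_lu_factor` is a plain-Python stub that defers to scipy (so merely importing PyTensor has no side effects), while the `@overload` registration supplies the implementation numba compiles when the stub is called inside jitted code. A minimal self-contained sketch of that pattern (not the PyTensor code; the body is a placeholder rather than the real getrf call):

    import numpy as np
    import scipy.linalg
    from numba import njit
    from numba.extending import overload


    def _stub(A):
        # Pure-Python path: runs only when called outside jitted code.
        return scipy.linalg.lu_factor(A)


    @overload(_stub)
    def _stub_impl(A):
        # Compile-time hook: numba calls this with the argument *types*
        # and jits the returned impl in place of _stub.
        def impl(A):
            # Placeholder result for illustration; the real overload calls
            # a LAPACK getrf wrapper and checks its info code instead.
            return A.copy(), np.zeros(A.shape[0], dtype=np.int32)

        return impl


    @njit
    def jitted(A):
        return _stub(A)

For reference, the underlying LAPACK routine obtained via `scipy.linalg.get_lapack_funcs("getrf", (A,))` returns `(lu, piv, info)`, where a positive `info` flags an exactly singular `U`; that is the extra diagnostic the updated `_getrf` docstring refers to.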
tests/link/numba/test_slinalg.py

Lines changed: 23 additions & 2 deletions
@@ -521,10 +521,31 @@ def test_numba_lu(permute_l, p_indices, shape: tuple[int]):
 
     else:
         # compare_numba_and_py fails: NotImplementedError: Non-jitted BlockwiseWithCoreShape not implemented
-        nb_out = f(A_val.copy())
+        pt_out = f(A_val.copy())
         sp_out = scipy_linalg.lu(
             A_val.copy(), permute_l=permute_l, p_indices=p_indices, check_finite=False
         )
 
-        for a, b in zip(nb_out, sp_out, strict=True):
+        for a, b in zip(pt_out, sp_out, strict=True):
+            np.testing.assert_allclose(a, b)
+
+
+@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
+def test_numba_lu_factor(shape: tuple[int]):
+    rng = np.random.default_rng(utt.fetch_seed())
+    A = pt.tensor("A", shape=shape, dtype=config.floatX)
+    out = pt.linalg.lu_factor(A)
+
+    A_val = rng.normal(size=shape).astype(config.floatX)
+    f = pytensor.function([A], out, mode="NUMBA")
+
+    if len(shape) == 2:
+        compare_numba_and_py([A], out, [A_val], inplace=True)
+    else:
+        pt_out = f(A_val.copy())
+        sp_out = np.vectorize(scipy_linalg.lu_factor, signature="(n,n)->(n,n),(n)")(
+            A_val.copy()
+        )
+
+        for a, b in zip(pt_out, sp_out, strict=True):
             np.testing.assert_allclose(a, b)
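As a sanity check on the semantics the new test asserts, here is a standalone scipy/numpy snippet (no PyTensor involved) showing the `(lu, piv)` pair that `lu_factor` returns, the `np.vectorize` signature trick the test uses to batch it over a stack of matrices, and how `lu_solve` consumes the factorization:

    import numpy as np
    from scipy import linalg

    rng = np.random.default_rng(0)

    # Single matrix: packed L and U in one array, plus pivot indices.
    A = rng.normal(size=(5, 5))
    lu, piv = linalg.lu_factor(A)
    b = rng.normal(size=5)
    x = linalg.lu_solve((lu, piv), b)  # reuse the factorization to solve A x = b
    np.testing.assert_allclose(A @ x, b)

    # Batched: the same gufunc-style signature as in the test maps
    # lu_factor over the leading batch dimension.
    A_batch = rng.normal(size=(3, 5, 5))
    lu_b, piv_b = np.vectorize(linalg.lu_factor, signature="(n,n)->(n,n),(n)")(A_batch)
    assert lu_b.shape == (3, 5, 5) and piv_b.shape == (3, 5)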
