
Commit 93314a5 (parent: a45b108)

Add numba dispatch for lu_factor

2 files changed (+48, -3 lines)

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 25 additions & 1 deletion
@@ -478,7 +478,8 @@ def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
 
 def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
     """
-    Placeholder for LU factorization; used by linalg.solve.
+    Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factor, this function also
+    returns an info code with diagnostic information.
     """
     getrf = scipy.linalg.get_lapack_funcs("getrf", (A,))
     A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)
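The info code the new docstring mentions follows LAPACK's getrf convention: 0 means success, a positive value i means U[i-1, i-1] is exactly zero (a singular factor), and a negative value flags an invalid argument. A small standalone illustration using scipy's LAPACK bindings, with illustrative inputs that are not part of this commit:

import numpy as np
import scipy.linalg

A = np.eye(3)
getrf = scipy.linalg.get_lapack_funcs("getrf", (A,))

# Well-conditioned input: factorization succeeds, info == 0.
lu, ipiv, info = getrf(A, overwrite_a=False)
print(info)  # 0

# A singular matrix returns a positive info code instead of raising:
lu, ipiv, info = getrf(np.zeros((3, 3)))
print(info)  # typically 1 here: U[0, 0] is exactly zero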
@@ -519,6 +520,29 @@ def impl(
     return impl
 
 
+def _lu_factor(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray]:
+    """
+    Thin wrapper around scipy.linalg.lu_factor. Used as an overload target to avoid side effects on users who import
+    PyTensor.
+    """
+    return linalg.lu_factor(A, overwrite_a=overwrite_a)
+
+
+@overload(_lu_factor)
+def lu_factor_impl(
+    A: np.ndarray, overwrite_a: bool = False
+) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
+    ensure_lapack()
+    _check_scipy_linalg_matrix(A, "lu_factor")
+
+    def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
+        A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
+        _solve_check(int_ptr_to_val(INFO), 0)
+        return A_copy, IPIV
+
+    return impl
+
+
 def _lu_1(
     a: np.ndarray,
     permute_l: bool,
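For readers unfamiliar with numba's dispatch machinery, the pattern above (a plain-Python wrapper plus an @overload registration) can be sketched in isolation. Everything below is a hypothetical, simplified example: the function names are made up, and the real _lu_factor overload additionally routes through the LAPACK _getrf wrapper and validates the info code with _solve_check.

import numpy as np
from numba import njit
from numba.extending import overload


def py_trace(a):
    # Plain-Python path: runs when called outside jitted code, so importing
    # this module registers nothing on numpy or scipy themselves.
    return np.trace(a)


@overload(py_trace)
def py_trace_impl(a):
    # numba calls this at compile time with the *type* of `a`, and compiles
    # the returned `impl` in place of `py_trace` inside jitted code.
    def impl(a):
        n = min(a.shape[0], a.shape[1])
        total = a[0, 0] - a[0, 0]  # zero of the array's dtype
        for i in range(n):
            total += a[i, i]
        return total

    return impl


@njit
def f(a):
    # Inside jitted code, this call dispatches to `impl`, not np.trace.
    return py_trace(a)


print(f(np.arange(9.0).reshape(3, 3)))  # 12.0, matches np.trace

Overloading a private wrapper rather than scipy.linalg.lu_factor itself is what the new docstring means by avoiding side effects: the registration stays local to PyTensor's dispatch and never changes what scipy.linalg.lu_factor means inside other users' jitted code.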

tests/link/numba/test_slinalg.py

Lines changed: 23 additions & 2 deletions
@@ -523,10 +523,31 @@ def test_numba_lu(permute_l, p_indices, shape: tuple[int]):
 
     else:
         # compare_numba_and_py fails: NotImplementedError: Non-jitted BlockwiseWithCoreShape not implemented
-        nb_out = f(A_val.copy())
+        pt_out = f(A_val.copy())
         sp_out = scipy_linalg.lu(
             A_val.copy(), permute_l=permute_l, p_indices=p_indices, check_finite=False
         )
 
-        for a, b in zip(nb_out, sp_out, strict=True):
+        for a, b in zip(pt_out, sp_out, strict=True):
+            np.testing.assert_allclose(a, b)
+
+
+@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
+def test_numba_lu_factor(shape: tuple[int]):
+    rng = np.random.default_rng(utt.fetch_seed())
+    A = pt.tensor("A", shape=shape, dtype=config.floatX)
+    out = pt.linalg.lu_factor(A)
+
+    A_val = rng.normal(size=shape).astype(config.floatX)
+    f = pytensor.function([A], out, mode="NUMBA")
+
+    if len(shape) == 2:
+        compare_numba_and_py([A], out, [A_val], inplace=True)
+    else:
+        pt_out = f(A_val.copy())
+        sp_out = np.vectorize(scipy_linalg.lu_factor, signature="(n,n)->(n,n),(n)")(
+            A_val.copy()
+        )
+
+        for a, b in zip(pt_out, sp_out, strict=True):
             np.testing.assert_allclose(a, b)
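The batched branch of the test builds its scipy reference by lifting lu_factor over the leading axis with np.vectorize and a gufunc signature. A standalone sketch of what that call does, with illustrative values that are not taken from the test suite:

import numpy as np
import scipy.linalg as scipy_linalg

# "(n,n)->(n,n),(n)" tells np.vectorize to treat the last two axes as the
# core matrix and loop lu_factor over any leading batch dimensions.
batched_lu_factor = np.vectorize(
    scipy_linalg.lu_factor, signature="(n,n)->(n,n),(n)"
)

A = np.random.default_rng(0).normal(size=(3, 5, 5))
LU, piv = batched_lu_factor(A)
print(LU.shape, piv.shape)  # (3, 5, 5) and (3, 5)

# Equivalent explicit loop over the batch dimension:
for i in range(A.shape[0]):
    lu_i, piv_i = scipy_linalg.lu_factor(A[i])
    np.testing.assert_allclose(LU[i], lu_i)
    np.testing.assert_allclose(piv[i], piv_i)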
