Skip to content

Commit 65082f2

Browse files
Add numba dispatch for lu_factor
1 parent 9cb431e commit 65082f2

File tree

2 files changed

+48
-3
lines changed

2 files changed

+48
-3
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,8 @@ def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
494494

495495
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
496496
"""
497-
Placeholder for LU factorization; used by linalg.solve.
497+
Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factor, this function also
498+
returns an info code with diagnostic information.
498499
"""
499500
getrf = scipy.linalg.get_lapack_funcs("getrf", (A,))
500501
A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)
@@ -535,6 +536,29 @@ def impl(
535536
return impl
536537

537538

539+
def _lu_factor(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray]:
540+
"""
541+
Thin wrapper around scipy.linalg.lu_factor. Used as an overload target to avoid side-effects on users who import
542+
Pytensor.
543+
"""
544+
return linalg.lu_factor(A, overwrite_a=overwrite_a)
545+
546+
547+
@overload(_lu_factor)
def lu_factor_impl(
    A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
    """
    Numba overload of ``_lu_factor``: builds a jittable implementation of
    scipy.linalg.lu_factor on top of the LAPACK ``getrf`` routine.

    Returns the ``impl`` function that numba compiles in place of the
    pure-python wrapper.
    """
    # Fail early (at compile time) if LAPACK bindings are unavailable or the
    # argument is not a supported matrix type.
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "lu_factor")

    def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
        # _getrf wraps LAPACK *getrf; in the jitted path INFO appears to be
        # returned as a pointer, hence the int_ptr_to_val dereference below
        # — NOTE(review): confirm against _getrf's numba overload.
        A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
        # Raise (via _solve_check) if LAPACK reported a non-zero info code,
        # e.g. a singular factor.
        _solve_check(int_ptr_to_val(INFO), 0)
        # Same return convention as scipy.linalg.lu_factor: (lu, piv).
        return A_copy, IPIV

    return impl
560+
561+
538562
def _lu_1(
539563
a: np.ndarray,
540564
permute_l: bool,

tests/link/numba/test_slinalg.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -501,10 +501,31 @@ def test_numba_lu(permute_l, p_indices, shape: tuple[int]):
501501

502502
else:
503503
# compare_numba_and_py fails: NotImplementedError: Non-jitted BlockwiseWithCoreShape not implemented
504-
nb_out = f(A_val.copy())
504+
pt_out = f(A_val.copy())
505505
sp_out = scipy_linalg.lu(
506506
A_val.copy(), permute_l=permute_l, p_indices=p_indices, check_finite=False
507507
)
508508

509-
for a, b in zip(nb_out, sp_out, strict=True):
509+
for a, b in zip(pt_out, sp_out, strict=True):
510+
np.testing.assert_allclose(a, b)
511+
512+
513+
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
def test_numba_lu_factor(shape: tuple[int, ...]):
    """
    Check the numba dispatch of ``pt.linalg.lu_factor`` against scipy for a
    plain matrix and a batched (blockwise) input.
    """
    rng = np.random.default_rng(utt.fetch_seed())
    A = pt.tensor("A", shape=shape, dtype=config.floatX)
    out = pt.linalg.lu_factor(A)

    A_val = rng.normal(size=shape).astype(config.floatX)

    if len(shape) == 2:
        # compare_numba_and_py compiles its own NUMBA function internally.
        compare_numba_and_py([A], out, [A_val], inplace=True)
    else:
        # compare_numba_and_py fails on the batched case (non-jitted
        # BlockwiseWithCoreShape is not implemented), so compile manually and
        # compare against a vectorized scipy reference. The function is only
        # built in this branch to avoid a wasted NUMBA compilation above.
        f = pytensor.function([A], out, mode="NUMBA")
        pt_out = f(A_val.copy())
        sp_out = np.vectorize(scipy_linalg.lu_factor, signature="(n,n)->(n,n),(n)")(
            A_val.copy()
        )

        for a, b in zip(pt_out, sp_out, strict=True):
            np.testing.assert_allclose(a, b)

0 commit comments

Comments
 (0)