Commit b1f8c9d

Add numba dispatch for lu_factor
1 parent 87c5368 commit b1f8c9d

2 files changed: +48 -3 lines changed
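In end-user terms, this commit lets a PyTensor graph containing pt.linalg.lu_factor compile under the numba backend. A minimal usage sketch (variable names are illustrative; the (LU, piv) return pair mirrors scipy.linalg.lu_factor, as the new test below assumes):

import numpy as np
import pytensor
import pytensor.tensor as pt

# Build a graph that LU-factorizes a matrix, then compile it with numba.
A = pt.matrix("A")
out = pt.linalg.lu_factor(A)
f = pytensor.function([A], out, mode="NUMBA")

A_val = np.random.default_rng(0).normal(size=(5, 5))
LU, piv = f(A_val)  # same (LU, piv) pair as scipy.linalg.lu_factor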

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 25 additions & 1 deletion
@@ -472,7 +472,8 @@ def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
 
 def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
     """
-    Placeholder for LU factorization; used by linalg.solve.
+    Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factor, this function also
+    returns an info code with diagnostic information.
     """
     getrf = scipy.linalg.get_lapack_funcs("getrf", (A,))
     A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)
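For context, _getrf resolves the dtype-appropriate LAPACK getrf routine through scipy and calls it directly. A standalone sketch of that call (info semantics per LAPACK: 0 means success, a positive value flags a zero pivot):

import numpy as np
import scipy.linalg

A = np.random.default_rng(0).normal(size=(4, 4))

# Resolve the getrf routine matching A's dtype (dgetrf here), then call it.
getrf = scipy.linalg.get_lapack_funcs("getrf", (A,))
A_copy, ipiv, info = getrf(A, overwrite_a=False)
assert info == 0  # info > 0 would mean U[i, i] is exactly zero (singular)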
@@ -513,6 +514,29 @@ def impl(
     return impl
 
 
+def _lu_factor(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray]:
+    """
+    Thin wrapper around scipy.linalg.lu_factor. Used as an overload target to avoid side effects for users who
+    import PyTensor.
+    """
+    return linalg.lu_factor(A, overwrite_a=overwrite_a)
+
+
+@overload(_lu_factor)
+def lu_factor_impl(
+    A: np.ndarray, overwrite_a: bool = False
+) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
+    ensure_lapack()
+    _check_scipy_linalg_matrix(A, "lu_factor")
+
+    def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
+        A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
+        _solve_check(int_ptr_to_val(INFO), 0)
+        return A_copy, IPIV
+
+    return impl
+
+
 def _lu_1(
     a: np.ndarray,
     permute_l: bool,
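The key mechanism above is numba's @overload decorator: the plain-Python wrapper stays importable (and simply defers to scipy), while jitted callers transparently get the registered implementation. A self-contained sketch of the pattern with a hypothetical toy function (clip_scalar is not part of PyTensor or numba):

from numba import njit
from numba.extending import overload


def clip_scalar(x, lo, hi):
    # Plain-Python version, used when called outside jitted code.
    return min(max(x, lo), hi)


@overload(clip_scalar)
def clip_scalar_impl(x, lo, hi):
    # Invoked by numba at compile time with argument *types*; the returned
    # impl is what actually gets compiled in place of clip_scalar.
    def impl(x, lo, hi):
        return min(max(x, lo), hi)

    return impl


@njit
def demo(x):
    return clip_scalar(x, 0.0, 1.0)


demo(2.5)  # -> 1.0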

tests/link/numba/test_slinalg.py

Lines changed: 23 additions & 2 deletions
@@ -480,10 +480,31 @@ def test_numba_lu(permute_l, p_indices, shape: tuple[int]):
 
     else:
         # compare_numba_and_py fails: NotImplementedError: Non-jitted BlockwiseWithCoreShape not implemented
-        nb_out = f(A_val.copy())
+        pt_out = f(A_val.copy())
         sp_out = scipy_linalg.lu(
             A_val.copy(), permute_l=permute_l, p_indices=p_indices, check_finite=False
         )
 
-        for a, b in zip(nb_out, sp_out, strict=True):
+        for a, b in zip(pt_out, sp_out, strict=True):
+            np.testing.assert_allclose(a, b)
+
+
+@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
+def test_numba_lu_factor(shape: tuple[int]):
+    rng = np.random.default_rng(utt.fetch_seed())
+    A = pt.tensor("A", shape=shape, dtype=config.floatX)
+    out = pt.linalg.lu_factor(A)
+
+    A_val = rng.normal(size=shape).astype(config.floatX)
+    f = pytensor.function([A], out, mode="NUMBA")
+
+    if len(shape) == 2:
+        compare_numba_and_py([A], out, [A_val], inplace=True)
+    else:
+        pt_out = f(A_val.copy())
+        sp_out = np.vectorize(scipy_linalg.lu_factor, signature="(n,n)->(n,n),(n)")(
+            A_val.copy()
+        )
+
+        for a, b in zip(pt_out, sp_out, strict=True):
             np.testing.assert_allclose(a, b)
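The batched reference in the new test is built with np.vectorize and a gufunc-style signature, which loops the strictly 2-D scipy.linalg.lu_factor over any leading batch dimensions. A standalone sketch:

import numpy as np
import scipy.linalg

# One (n, n) input mapped to an (n, n) LU matrix and an (n,) pivot vector;
# np.vectorize loops this core operation over the leading batch axis.
batched_lu_factor = np.vectorize(
    scipy.linalg.lu_factor, signature="(n,n)->(n,n),(n)"
)

A = np.random.default_rng(0).normal(size=(3, 5, 5))
LU, piv = batched_lu_factor(A)
assert LU.shape == (3, 5, 5) and piv.shape == (3, 5)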
