Skip to content

Commit 70ef520

Browse files
Add numba dispatch for LU
1 parent da924bf commit 70ef520

File tree

3 files changed

+297
-3
lines changed

3 files changed

+297
-3
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 264 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import warnings
22
from collections.abc import Callable
3+
from typing import cast as typing_cast
34

45
import numba
56
import numpy as np
7+
import scipy.linalg
68
from numba.core import types
79
from numba.extending import overload
810
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
@@ -18,6 +20,7 @@
1820
)
1921
from pytensor.link.numba.dispatch.basic import numba_funcify
2022
from pytensor.tensor.slinalg import (
23+
LU,
2124
BlockDiagonal,
2225
Cholesky,
2326
CholeskySolve,
@@ -476,10 +479,11 @@ def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
476479
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
477480
"""
478481
Placeholder for LU factorization; used by linalg.solve.
479-
480-
# TODO: Implement an LU_factor Op, then dispatch to this function in numba mode.
481482
"""
482-
return # type: ignore
483+
getrf = scipy.linalg.get_lapack_funcs("getrf", (A,))
484+
A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)
485+
486+
return A_copy, ipiv
483487

484488

485489
@overload(_getrf)
@@ -515,6 +519,263 @@ def impl(
515519
return impl
516520

517521

522+
def _lu_1(
523+
a: np.ndarray,
524+
permute_l: bool,
525+
check_finite: bool,
526+
p_indices: bool,
527+
overwrite_a: bool,
528+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
529+
"""
530+
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
531+
532+
Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer
533+
array of row swaps, such that L[perm] @ U = A.
534+
"""
535+
return typing_cast(
536+
tuple[np.ndarray, np.ndarray, np.ndarray],
537+
linalg.lu(
538+
a,
539+
permute_l=permute_l,
540+
check_finite=check_finite,
541+
p_indices=p_indices,
542+
overwrite_a=overwrite_a,
543+
),
544+
)
545+
546+
547+
def _lu_2(
548+
a: np.ndarray,
549+
permute_l: bool,
550+
check_finite: bool,
551+
p_indices: bool,
552+
overwrite_a: bool,
553+
) -> tuple[np.ndarray, np.ndarray]:
554+
"""
555+
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
556+
557+
Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the
558+
permuted L matrix, PL = P @ L.
559+
"""
560+
return typing_cast(
561+
tuple[np.ndarray, np.ndarray],
562+
linalg.lu(
563+
a,
564+
permute_l=permute_l,
565+
check_finite=check_finite,
566+
p_indices=p_indices,
567+
overwrite_a=overwrite_a,
568+
),
569+
)
570+
571+
572+
def _lu_3(
573+
a: np.ndarray,
574+
permute_l: bool,
575+
check_finite: bool,
576+
p_indices: bool,
577+
overwrite_a: bool,
578+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
579+
"""
580+
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
581+
582+
Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation
583+
matrix, P @ L @ U = A.
584+
"""
585+
return typing_cast(
586+
tuple[np.ndarray, np.ndarray, np.ndarray],
587+
linalg.lu(
588+
a,
589+
permute_l=permute_l,
590+
check_finite=check_finite,
591+
p_indices=p_indices,
592+
overwrite_a=overwrite_a,
593+
),
594+
)
595+
596+
597+
@overload(_lu_1)
598+
def lu_impl_1(
599+
a: np.ndarray,
600+
permute_l: bool,
601+
check_finite: bool,
602+
p_indices: bool,
603+
overwrite_a: bool,
604+
) -> Callable[
605+
[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
606+
]:
607+
"""
608+
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
609+
False. Returns a tuple of (perm, L, U), where perm an integer array of row swaps, such that L[perm] @ U = A.
610+
"""
611+
ensure_lapack()
612+
_check_scipy_linalg_matrix(a, "lu")
613+
dtype = a.dtype
614+
615+
def impl(
616+
a: np.ndarray,
617+
permute_l: bool,
618+
check_finite: bool,
619+
p_indices: bool,
620+
overwrite_a: bool,
621+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
622+
A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a)
623+
624+
L = np.eye(A_copy.shape[-1], dtype=dtype)
625+
L += np.tril(A_copy, k=-1)
626+
U = np.triu(A_copy)
627+
628+
# Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
629+
IPIV = IPIV - 1
630+
p_inv = np.arange(len(IPIV))
631+
for i in range(len(IPIV)):
632+
p_inv[i], p_inv[IPIV[i]] = p_inv[IPIV[i]], p_inv[i]
633+
634+
perm = np.argsort(p_inv)
635+
return perm, L, U
636+
637+
return impl
638+
639+
640+
@overload(_lu_2)
641+
def lu_impl_2(
642+
a: np.ndarray,
643+
permute_l: bool,
644+
check_finite: bool,
645+
p_indices: bool,
646+
overwrite_a: bool,
647+
) -> Callable[[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray]]:
648+
"""
649+
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is
650+
True. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L.
651+
"""
652+
653+
ensure_lapack()
654+
_check_scipy_linalg_matrix(a, "lu")
655+
dtype = a.dtype
656+
657+
def impl(
658+
a: np.ndarray,
659+
permute_l: bool,
660+
check_finite: bool,
661+
p_indices: bool,
662+
overwrite_a: bool,
663+
) -> tuple[np.ndarray, np.ndarray]:
664+
A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a)
665+
666+
L = np.eye(A_copy.shape[-1], dtype=dtype)
667+
L += np.tril(A_copy, k=-1)
668+
U = np.triu(A_copy)
669+
670+
# Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
671+
IPIV = IPIV - 1
672+
p_inv = np.arange(len(IPIV))
673+
for i in range(len(IPIV)):
674+
p_inv[i], p_inv[IPIV[i]] = p_inv[IPIV[i]], p_inv[i]
675+
676+
perm = np.argsort(p_inv)
677+
PL = L[perm]
678+
return PL, U
679+
680+
return impl
681+
682+
683+
@overload(_lu_3)
684+
def lu_impl_3(
685+
a: np.ndarray,
686+
permute_l: bool,
687+
check_finite: bool,
688+
p_indices: bool,
689+
overwrite_a: bool,
690+
) -> Callable[
691+
[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
692+
]:
693+
"""
694+
Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
695+
False. Returns a tuple of (P, L, U), such that P @ L @ U = A.
696+
"""
697+
ensure_lapack()
698+
_check_scipy_linalg_matrix(a, "lu")
699+
dtype = a.dtype
700+
701+
def impl(
702+
a: np.ndarray,
703+
permute_l: bool,
704+
check_finite: bool,
705+
p_indices: bool,
706+
overwrite_a: bool,
707+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
708+
A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a)
709+
710+
L = np.eye(A_copy.shape[-1], dtype=dtype)
711+
L += np.tril(A_copy, k=-1)
712+
U = np.triu(A_copy)
713+
714+
# Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
715+
IPIV = IPIV - 1
716+
p_inv = np.arange(len(IPIV))
717+
for i in range(len(IPIV)):
718+
p_inv[i], p_inv[IPIV[i]] = p_inv[IPIV[i]], p_inv[i]
719+
720+
perm = np.argsort(p_inv)
721+
P = np.eye(A_copy.shape[-1], dtype=dtype)[perm]
722+
723+
return P, L, U
724+
725+
return impl
726+
727+
728+
@numba_funcify.register(LU)
729+
def numba_funcify_LU(op, node, **kwargs):
730+
permute_l = op.permute_l
731+
check_finite = op.check_finite
732+
p_indices = op.p_indices
733+
overwrite_a = op.overwrite_a
734+
735+
dtype = node.inputs[0].dtype
736+
if str(dtype).startswith("complex"):
737+
raise NotImplementedError(
738+
"Complex inputs not currently supported by lu in Numba mode"
739+
)
740+
741+
@numba_basic.numba_njit(inline="always")
742+
def lu(a):
743+
if check_finite:
744+
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
745+
raise np.linalg.LinAlgError(
746+
"Non-numeric values (nan or inf) found in input to lu"
747+
)
748+
749+
if p_indices:
750+
res = _lu_1(
751+
a,
752+
permute_l=permute_l,
753+
check_finite=check_finite,
754+
p_indices=p_indices,
755+
overwrite_a=overwrite_a,
756+
)
757+
elif permute_l:
758+
res = _lu_2(
759+
a,
760+
permute_l=permute_l,
761+
check_finite=check_finite,
762+
p_indices=p_indices,
763+
overwrite_a=overwrite_a,
764+
)
765+
else:
766+
res = _lu_3(
767+
a,
768+
permute_l=permute_l,
769+
check_finite=check_finite,
770+
p_indices=p_indices,
771+
overwrite_a=overwrite_a,
772+
)
773+
774+
return res
775+
776+
return lu
777+
778+
518779
def _getrs(
519780
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
520781
) -> tuple[np.ndarray, int]:

pytensor/tensor/slinalg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,4 +1447,5 @@ def block_diag(*matrices: TensorVariable):
14471447
"solve_triangular",
14481448
"block_diag",
14491449
"cho_solve",
1450+
"lu",
14501451
]

tests/link/numba/test_slinalg.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,3 +496,35 @@ def test_cho_solve(b_func, b_size, lower):
496496
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
497497

498498
np.testing.assert_allclose(A @ X_np, b, atol=ATOL, rtol=RTOL)
499+
500+
501+
@pytest.mark.parametrize(
502+
"permute_l, p_indices",
503+
[(True, False), (False, True), (False, False)],
504+
ids=["PL", "p_indices", "P"],
505+
)
506+
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
507+
def test_numba_lu(permute_l, p_indices, shape: tuple[int]):
508+
rng = np.random.default_rng()
509+
A = pt.tensor(
510+
"A",
511+
shape=shape,
512+
dtype=config.floatX,
513+
)
514+
515+
out = pt.linalg.lu(A, permute_l=permute_l, p_indices=p_indices)
516+
f = pytensor.function([A], out, mode="NUMBA")
517+
518+
A_val = rng.normal(size=shape).astype(config.floatX)
519+
if len(shape) == 2:
520+
compare_numba_and_py([A], out, test_inputs=[A_val], inplace=True)
521+
522+
else:
523+
# compare_numba_and_py fails: NotImplementedError: Non-jitted BlockwiseWithCoreShape not implemented
524+
nb_out = f(A_val.copy())
525+
sp_out = scipy_linalg.lu(
526+
A_val.copy(), permute_l=permute_l, p_indices=p_indices, check_finite=False
527+
)
528+
529+
for a, b in zip(nb_out, sp_out, strict=True):
530+
np.testing.assert_allclose(a, b)

0 commit comments

Comments
 (0)