Skip to content

Commit aa3199b

Browse files
Add numba dispatch for LU
1 parent a44c0a2 commit aa3199b

File tree

3 files changed

+297
-3
lines changed

3 files changed

+297
-3
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 264 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from collections.abc import Callable
2+
from typing import cast as typing_cast
23

34
import numba
45
import numpy as np
6+
import scipy.linalg
57
from numba.core import types
68
from numba.extending import overload
79
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
@@ -17,6 +19,7 @@
1719
)
1820
from pytensor.link.numba.dispatch.basic import numba_funcify
1921
from pytensor.tensor.slinalg import (
22+
LU,
2023
BlockDiagonal,
2124
Cholesky,
2225
CholeskySolve,
@@ -470,10 +473,11 @@ def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
470473
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
    """
    Compute the LU factorization of ``A`` via the LAPACK ``getrf`` routine.

    Pure-Python counterpart of the numba overload registered on this function;
    used by linalg.solve and by the LU dispatch.

    Parameters
    ----------
    A : np.ndarray
        Matrix to factorize. Its dtype selects the LAPACK flavour.
    overwrite_a : bool
        Forwarded to the LAPACK wrapper; allows it to reuse ``A``'s memory.

    Returns
    -------
    tuple
        ``(A_copy, ipiv, info)`` exactly as produced by the scipy LAPACK
        wrapper: the packed LU factors, the pivot indices and the LAPACK
        status code.
    """
    getrf = scipy.linalg.get_lapack_funcs("getrf", (A,))
    A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)

    # All three values are returned: the annotation promises a 3-tuple and
    # callers unpack ``A_copy, IPIV, INFO``.
    return A_copy, ipiv, info
477481

478482

479483
@overload(_getrf)
@@ -509,6 +513,263 @@ def impl(
509513
return impl
510514

511515

516+
def _lu_1(
517+
a: np.ndarray,
518+
permute_l: bool,
519+
check_finite: bool,
520+
p_indices: bool,
521+
overwrite_a: bool,
522+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
523+
"""
524+
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
525+
526+
Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer
527+
array of row swaps, such that L[perm] @ U = A.
528+
"""
529+
return typing_cast(
530+
tuple[np.ndarray, np.ndarray, np.ndarray],
531+
linalg.lu(
532+
a,
533+
permute_l=permute_l,
534+
check_finite=check_finite,
535+
p_indices=p_indices,
536+
overwrite_a=overwrite_a,
537+
),
538+
)
539+
540+
541+
def _lu_2(
542+
a: np.ndarray,
543+
permute_l: bool,
544+
check_finite: bool,
545+
p_indices: bool,
546+
overwrite_a: bool,
547+
) -> tuple[np.ndarray, np.ndarray]:
548+
"""
549+
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
550+
551+
Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the
552+
permuted L matrix, PL = P @ L.
553+
"""
554+
return typing_cast(
555+
tuple[np.ndarray, np.ndarray],
556+
linalg.lu(
557+
a,
558+
permute_l=permute_l,
559+
check_finite=check_finite,
560+
p_indices=p_indices,
561+
overwrite_a=overwrite_a,
562+
),
563+
)
564+
565+
566+
def _lu_3(
567+
a: np.ndarray,
568+
permute_l: bool,
569+
check_finite: bool,
570+
p_indices: bool,
571+
overwrite_a: bool,
572+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
573+
"""
574+
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
575+
576+
Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation
577+
matrix, P @ L @ U = A.
578+
"""
579+
return typing_cast(
580+
tuple[np.ndarray, np.ndarray, np.ndarray],
581+
linalg.lu(
582+
a,
583+
permute_l=permute_l,
584+
check_finite=check_finite,
585+
p_indices=p_indices,
586+
overwrite_a=overwrite_a,
587+
),
588+
)
589+
590+
591+
@overload(_lu_1)
def lu_impl_1(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
    """
    Numba overload for ``_lu_1``, the variant the LU dispatcher selects when p_indices is True.

    The returned implementation yields a tuple of (perm, L, U), where perm is an integer array of row indices,
    such that L[perm] @ U = A.
    """
    ensure_lapack()
    # Typing-time validation of the input (presumably dimensionality/dtype; see the helper).
    _check_scipy_linalg_matrix(a, "lu")
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        # INFO (the LAPACK status) is received but not inspected here.
        A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a)

        # Split the packed factorization: unit-diagonal L below the diagonal, U on and above it.
        L = np.eye(A_copy.shape[-1], dtype=dtype)
        L += np.tril(A_copy, k=-1)
        U = np.triu(A_copy)

        # Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
        IPIV = IPIV - 1
        # Replay the sequential row interchanges (row i swapped with row IPIV[i]) on an identity ordering.
        p_inv = np.arange(len(IPIV))
        for i in range(len(IPIV)):
            p_inv[i], p_inv[IPIV[i]] = p_inv[IPIV[i]], p_inv[i]

        # argsort inverts the accumulated permutation.
        perm = np.argsort(p_inv)
        return perm, L, U

    return impl
632+
633+
634+
@overload(_lu_2)
def lu_impl_2(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray]]:
    """
    Numba overload for ``_lu_2``, the variant the LU dispatcher selects when permute_l is True and p_indices is
    False.

    The returned implementation yields a tuple of (PL, U), where PL is the row-permuted L matrix, ``L[perm]``.
    """

    ensure_lapack()
    # Typing-time validation of the input (presumably dimensionality/dtype; see the helper).
    _check_scipy_linalg_matrix(a, "lu")
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray]:
        # INFO (the LAPACK status) is received but not inspected here.
        A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a)

        # Split the packed factorization: unit-diagonal L below the diagonal, U on and above it.
        L = np.eye(A_copy.shape[-1], dtype=dtype)
        L += np.tril(A_copy, k=-1)
        U = np.triu(A_copy)

        # Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
        IPIV = IPIV - 1
        # Replay the sequential row interchanges (row i swapped with row IPIV[i]) on an identity ordering.
        p_inv = np.arange(len(IPIV))
        for i in range(len(IPIV)):
            p_inv[i], p_inv[IPIV[i]] = p_inv[IPIV[i]], p_inv[i]

        # argsort inverts the accumulated permutation; indexing L with it folds the permutation into L.
        perm = np.argsort(p_inv)
        PL = L[perm]
        return PL, U

    return impl
675+
676+
677+
@overload(_lu_3)
def lu_impl_3(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
    """
    Numba overload for ``_lu_3``, the variant the LU dispatcher selects when permute_l is False and p_indices is
    False.

    The returned implementation yields a tuple of (P, L, U), such that P @ L @ U = A.
    """
    ensure_lapack()
    # Typing-time validation of the input (presumably dimensionality/dtype; see the helper).
    _check_scipy_linalg_matrix(a, "lu")
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        # INFO (the LAPACK status) is received but not inspected here.
        A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a)

        # Split the packed factorization: unit-diagonal L below the diagonal, U on and above it.
        L = np.eye(A_copy.shape[-1], dtype=dtype)
        L += np.tril(A_copy, k=-1)
        U = np.triu(A_copy)

        # Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
        IPIV = IPIV - 1
        # Replay the sequential row interchanges (row i swapped with row IPIV[i]) on an identity ordering.
        p_inv = np.arange(len(IPIV))
        for i in range(len(IPIV)):
            p_inv[i], p_inv[IPIV[i]] = p_inv[IPIV[i]], p_inv[i]

        # argsort inverts the accumulated permutation; permuting the identity's rows gives the dense matrix P.
        perm = np.argsort(p_inv)
        P = np.eye(A_copy.shape[-1], dtype=dtype)[perm]

        return P, L, U

    return impl
720+
721+
722+
@numba_funcify.register(LU)
def numba_funcify_LU(op, node, **kwargs):
    """
    Build the numba implementation of the ``LU`` Op.

    The Op's ``permute_l``, ``check_finite``, ``p_indices`` and ``overwrite_a`` settings are read once here and
    captured by the returned jitted function, which picks one of ``_lu_1`` / ``_lu_2`` / ``_lu_3`` accordingly.

    Raises
    ------
    NotImplementedError
        If the dtype of the node's first input is complex.
    """
    permute_l = op.permute_l
    check_finite = op.check_finite
    p_indices = op.p_indices
    overwrite_a = op.overwrite_a

    dtype = node.inputs[0].dtype
    if str(dtype).startswith("complex"):
        raise NotImplementedError(
            "Complex inputs not currently supported by lu in Numba mode"
        )

    @numba_basic.numba_njit(inline="always")
    def lu(a):
        # Optional input validation: reject any inf or nan entry before factorizing.
        if check_finite:
            if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) found in input to lu"
                )

        # p_indices takes precedence over permute_l; the three variants return
        # (perm, L, U), (PL, U) and (P, L, U) respectively.
        if p_indices:
            res = _lu_1(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )
        elif permute_l:
            res = _lu_2(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )
        else:
            res = _lu_3(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )

        return res

    return lu
771+
772+
512773
def _getrs(
513774
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
514775
) -> tuple[np.ndarray, int]:

pytensor/tensor/slinalg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,4 +1361,5 @@ def block_diag(*matrices: TensorVariable):
13611361
"solve_triangular",
13621362
"block_diag",
13631363
"cho_solve",
1364+
"lu",
13641365
]

tests/link/numba/test_slinalg.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,3 +455,35 @@ def test_cho_solve(b_func, b_size, lower):
455455
RTOL = 1e-8 if floatX.endswith("64") else 1e-4
456456

457457
np.testing.assert_allclose(A @ X_np, b, atol=ATOL, rtol=RTOL)
458+
459+
460+
@pytest.mark.parametrize(
    "permute_l, p_indices",
    [(True, False), (False, True), (False, False)],
    ids=["PL", "p_indices", "P"],
)
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
def test_numba_lu(permute_l, p_indices, shape: tuple[int, ...]):
    """Check the numba LU dispatch against the reference output for every output variant."""
    # Fixed seed so that a failure can be reproduced.
    rng = np.random.default_rng(20240516)
    A = pt.tensor(
        "A",
        shape=shape,
        dtype=config.floatX,
    )

    out = pt.linalg.lu(A, permute_l=permute_l, p_indices=p_indices)

    A_val = rng.normal(size=shape).astype(config.floatX)
    if len(shape) == 2:
        compare_numba_and_py([A], out, test_inputs=[A_val], inplace=True)

    else:
        # compare_numba_and_py fails: NotImplementedError: Non-jitted BlockwiseWithCoreShape not implemented
        # The function is compiled only in the branch that actually calls it.
        f = pytensor.function([A], out, mode="NUMBA")
        nb_out = f(A_val.copy())
        sp_out = scipy_linalg.lu(
            A_val.copy(), permute_l=permute_l, p_indices=p_indices, check_finite=False
        )

        # Tolerance follows the working precision; the default is too tight for float32.
        tol = 1e-8 if config.floatX.endswith("64") else 1e-4
        for a, b in zip(nb_out, sp_out, strict=True):
            np.testing.assert_allclose(a, b, atol=tol, rtol=tol)

0 commit comments

Comments
 (0)