Skip to content

Commit 593188b

Browse files
Add numba dispatch for LU
1 parent 2ab51d0 commit 593188b

File tree

3 files changed

+297
-3
lines changed

3 files changed

+297
-3
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 264 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import warnings
22
from collections.abc import Callable
3+
from typing import cast as typing_cast
34

45
import numba
56
import numpy as np
7+
import scipy.linalg
68
from numba.core import types
79
from numba.extending import overload
810
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
@@ -18,6 +20,7 @@
1820
)
1921
from pytensor.link.numba.dispatch.basic import numba_funcify
2022
from pytensor.tensor.slinalg import (
23+
LU,
2124
BlockDiagonal,
2225
Cholesky,
2326
CholeskySolve,
@@ -492,10 +495,11 @@ def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
492495
def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
493496
"""
494497
Placeholder for LU factorization; used by linalg.solve.
495-
496-
# TODO: Implement an LU_factor Op, then dispatch to this function in numba mode.
497498
"""
498-
return # type: ignore
499+
getrf = scipy.linalg.get_lapack_funcs("getrf", (A,))
500+
A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)
501+
502+
return A_copy, ipiv
499503

500504

501505
@overload(_getrf)
@@ -531,6 +535,263 @@ def impl(
531535
return impl
532536

533537

538+
def _lu_1(
539+
a: np.ndarray,
540+
permute_l: bool,
541+
check_finite: bool,
542+
p_indices: bool,
543+
overwrite_a: bool,
544+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
545+
"""
546+
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
547+
548+
Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer
549+
array of row swaps, such that L[perm] @ U = A.
550+
"""
551+
return typing_cast(
552+
tuple[np.ndarray, np.ndarray, np.ndarray],
553+
linalg.lu(
554+
a,
555+
permute_l=permute_l,
556+
check_finite=check_finite,
557+
p_indices=p_indices,
558+
overwrite_a=overwrite_a,
559+
),
560+
)
561+
562+
563+
def _lu_2(
564+
a: np.ndarray,
565+
permute_l: bool,
566+
check_finite: bool,
567+
p_indices: bool,
568+
overwrite_a: bool,
569+
) -> tuple[np.ndarray, np.ndarray]:
570+
"""
571+
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
572+
573+
Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the
574+
permuted L matrix, PL = P @ L.
575+
"""
576+
return typing_cast(
577+
tuple[np.ndarray, np.ndarray],
578+
linalg.lu(
579+
a,
580+
permute_l=permute_l,
581+
check_finite=check_finite,
582+
p_indices=p_indices,
583+
overwrite_a=overwrite_a,
584+
),
585+
)
586+
587+
588+
def _lu_3(
589+
a: np.ndarray,
590+
permute_l: bool,
591+
check_finite: bool,
592+
p_indices: bool,
593+
overwrite_a: bool,
594+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
595+
"""
596+
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
597+
598+
Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation
599+
matrix, P @ L @ U = A.
600+
"""
601+
return typing_cast(
602+
tuple[np.ndarray, np.ndarray, np.ndarray],
603+
linalg.lu(
604+
a,
605+
permute_l=permute_l,
606+
check_finite=check_finite,
607+
p_indices=p_indices,
608+
overwrite_a=overwrite_a,
609+
),
610+
)
611+
612+
613+
@overload(_lu_1)
def lu_impl_1(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
    """
    Numba implementation registered for ``_lu_1``, the variant the dispatcher uses when p_indices is True.

    Returns a tuple of (perm, L, U), where perm is an integer array of row indices such that L[perm] @ U = A.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        # INFO is returned by _getrf but is not inspected here.
        A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a)

        # L is the identity plus the strictly lower triangle of the factorized array;
        # U is its upper triangle including the diagonal.
        L = np.eye(A_copy.shape[-1], dtype=dtype)
        L += np.tril(A_copy, k=-1)
        U = np.triu(A_copy)

        # Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
        IPIV = IPIV - 1
        # Apply the recorded swaps one after another (entry i with entry IPIV[i]) to an
        # identity index vector, then invert the result with argsort.
        p_inv = np.arange(len(IPIV))
        for i in range(len(IPIV)):
            p_inv[i], p_inv[IPIV[i]] = p_inv[IPIV[i]], p_inv[i]

        perm = np.argsort(p_inv)
        return perm, L, U

    return impl
654+
655+
656+
@overload(_lu_2)
def lu_impl_2(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray]]:
    """
    Numba implementation registered for ``_lu_2``, the variant the dispatcher uses when permute_l is True
    (and p_indices is False).

    Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L.
    """

    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray]:
        # INFO is returned by _getrf but is not inspected here.
        A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a)

        # L is the identity plus the strictly lower triangle of the factorized array;
        # U is its upper triangle including the diagonal.
        L = np.eye(A_copy.shape[-1], dtype=dtype)
        L += np.tril(A_copy, k=-1)
        U = np.triu(A_copy)

        # Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
        IPIV = IPIV - 1
        # Apply the recorded swaps one after another (entry i with entry IPIV[i]) to an
        # identity index vector, then invert the result with argsort.
        p_inv = np.arange(len(IPIV))
        for i in range(len(IPIV)):
            p_inv[i], p_inv[IPIV[i]] = p_inv[IPIV[i]], p_inv[i]

        perm = np.argsort(p_inv)
        # Reordering the rows of L by perm yields the permuted factor.
        PL = L[perm]
        return PL, U

    return impl
697+
698+
699+
@overload(_lu_3)
def lu_impl_3(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
    """
    Numba implementation registered for ``_lu_3``, the variant the dispatcher uses when both permute_l and
    p_indices are False.

    Returns a tuple of (P, L, U), such that P @ L @ U = A.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        # INFO is returned by _getrf but is not inspected here.
        A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a)

        # L is the identity plus the strictly lower triangle of the factorized array;
        # U is its upper triangle including the diagonal.
        L = np.eye(A_copy.shape[-1], dtype=dtype)
        L += np.tril(A_copy, k=-1)
        U = np.triu(A_copy)

        # Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
        IPIV = IPIV - 1
        # Apply the recorded swaps one after another (entry i with entry IPIV[i]) to an
        # identity index vector, then invert the result with argsort.
        p_inv = np.arange(len(IPIV))
        for i in range(len(IPIV)):
            p_inv[i], p_inv[IPIV[i]] = p_inv[IPIV[i]], p_inv[i]

        perm = np.argsort(p_inv)
        # Build the permutation matrix by reordering the rows of the identity.
        P = np.eye(A_copy.shape[-1], dtype=dtype)[perm]

        return P, L, U

    return impl
742+
743+
744+
@numba_funcify.register(LU)
def numba_funcify_LU(op, node, **kwargs):
    """
    Build the numba-jitted function that evaluates an ``LU`` Op.

    The Op's flags are read once here and captured by the jitted closure, which forwards to one of the
    ``_lu_1`` / ``_lu_2`` / ``_lu_3`` overload targets depending on ``p_indices`` and ``permute_l``.

    Raises
    ------
    NotImplementedError
        If the dtype of the first input starts with "complex".
    """
    permute_l = op.permute_l
    check_finite = op.check_finite
    p_indices = op.p_indices
    overwrite_a = op.overwrite_a

    input_dtype = str(node.inputs[0].dtype)
    if input_dtype.startswith("complex"):
        raise NotImplementedError(
            "Complex inputs not currently supported by lu in Numba mode"
        )

    @numba_basic.numba_njit(inline="always")
    def lu(a):
        if check_finite:
            if np.any(np.isinf(a) | np.isnan(a)):
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) found in input to lu"
                )

        # NOTE(review): p_indices is tested before permute_l, so if both flags were True the
        # three-output variant would be chosen. Confirm that the LU Op never sets both, or
        # that this precedence matches the Op's declared outputs.
        if p_indices:
            factors = _lu_1(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )
        elif permute_l:
            factors = _lu_2(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )
        else:
            factors = _lu_3(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )

        return factors

    return lu
793+
794+
534795
def _getrs(
535796
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
536797
) -> tuple[np.ndarray, int]:

pytensor/tensor/slinalg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,4 +1447,5 @@ def block_diag(*matrices: TensorVariable):
14471447
"solve_triangular",
14481448
"block_diag",
14491449
"cho_solve",
1450+
"lu",
14501451
]

tests/link/numba/test_slinalg.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,3 +476,35 @@ def test_block_diag():
476476
C_val = np.random.normal(size=(2, 2)).astype(floatX)
477477
D_val = np.random.normal(size=(4, 4)).astype(floatX)
478478
compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val])
479+
480+
481+
@pytest.mark.parametrize(
    "permute_l, p_indices",
    [(True, False), (False, True), (False, False)],
    ids=["PL", "p_indices", "P"],
)
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
def test_numba_lu(permute_l, p_indices, shape: tuple[int, ...]):
    """
    Check the numba LU dispatch against the reference result for each output variant.

    Non-batched inputs go through ``compare_numba_and_py``; batched inputs are compared directly
    against ``scipy_linalg.lu`` because ``compare_numba_and_py`` fails for them (see the comment below).
    """
    # Fixed seed so that a failing case can be reproduced with the same matrix.
    rng = np.random.default_rng(0)
    A = pt.tensor(
        "A",
        shape=shape,
        dtype=config.floatX,
    )

    out = pt.linalg.lu(A, permute_l=permute_l, p_indices=p_indices)

    A_val = rng.normal(size=shape).astype(config.floatX)
    if len(shape) == 2:
        compare_numba_and_py([A], out, test_inputs=[A_val], inplace=True)

    else:
        # compare_numba_and_py fails: NotImplementedError: Non-jitted BlockwiseWithCoreShape not implemented
        # The function is compiled only on this path, the only one that calls it.
        f = pytensor.function([A], out, mode="NUMBA")
        nb_out = f(A_val.copy())
        sp_out = scipy_linalg.lu(
            A_val.copy(), permute_l=permute_l, p_indices=p_indices, check_finite=False
        )

        for a, b in zip(nb_out, sp_out, strict=True):
            np.testing.assert_allclose(a, b)

0 commit comments

Comments
 (0)