Skip to content

Commit 191b755

Browse files
committed
Implement specialized PivotToPermutation
1 parent fb26047 commit 191b755

File tree

2 files changed

+44
-32
lines changed

2 files changed

+44
-32
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
int_ptr_to_val,
1919
val_to_int_ptr,
2020
)
21-
from pytensor.link.numba.dispatch.basic import numba_funcify
21+
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
2222
from pytensor.tensor.slinalg import (
2323
LU,
2424
BlockDiagonal,
2525
Cholesky,
2626
CholeskySolve,
27+
LUFactor,
28+
PivotToPermutations,
2729
Solve,
2830
SolveTriangular,
2931
)
@@ -36,6 +38,18 @@
3638
)
3739

3840

41+
@numba_funcify.register(PivotToPermutations)
42+
def pivot_to_permutation(op, node, **kwargs):
43+
@numba_basic.numba_njit
44+
def pivot_to_permutation(p):
45+
p_inv = np.arange(len(p))
46+
for i in range(len(p)):
47+
p_inv[i], p_inv[p[i]] = p_inv[p[i]], p_inv[i]
48+
return p_inv
49+
50+
return pivot_to_permutation
51+
52+
3953
@numba_basic.numba_njit(inline="always")
4054
def _copy_to_fortran_order_even_if_1d(x):
4155
# Numba's _copy_to_fortran_order doesn't do anything for vectors
@@ -818,6 +832,20 @@ def lu(a):
818832
return lu
819833

820834

835+
@numba_funcify.register(LUFactor)
836+
def numba_funcify_LUFactor(op, node, **kwargs):
837+
overwrite_a = op.overwrite_a
838+
839+
@numba_njit
840+
def lu_factor(A):
841+
N = A.shape[1]
842+
LU, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
843+
_solve_check(N, INFO)
844+
return LU, IPIV - 1
845+
846+
return lu_factor
847+
848+
821849
def _getrs(
822850
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
823851
) -> tuple[np.ndarray, int]:

pytensor/tensor/slinalg.py

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pytensor.tensor.blockwise import Blockwise
2222
from pytensor.tensor.nlinalg import kron, matrix_dot
2323
from pytensor.tensor.shape import reshape
24-
from pytensor.tensor.type import matrix, tensor, vector
24+
from pytensor.tensor.type import ivector, matrix, tensor, vector
2525
from pytensor.tensor.variable import TensorVariable
2626

2727

@@ -583,31 +583,21 @@ def lu(
583583
)
584584

585585

586-
def _pivot_to_permutation(pivots):
587-
"""
588-
Converts a sequence of row exchanges to a permutation matrix that represents the same row exchanges. This
589-
represents the inverse permutation, which can be used to reconstruct the original matrix from its LU factorization.
590-
To get the actual permutation, the inverse permutation must be argsorted.
591-
"""
586+
class PivotToPermutations(Op):
587+
itypes = [ivector]
588+
otypes = [ivector]
589+
590+
__props__ = ()
591+
592+
def perform(self, node, inputs, outputs):
593+
[p] = inputs
594+
p_inv = np.arange(len(p))
595+
for i in range(len(p)):
596+
p_inv[i], p_inv[p[i]] = p_inv[p[i]], p_inv[i]
597+
outputs[0][0] = p_inv
592598

593-
def step(i, permutation, swaps):
594-
j = swaps[i]
595-
x = permutation[i]
596-
y = permutation[j]
597-
598-
permutation = permutation[i].set(y)
599-
return permutation[j].set(x)
600-
601-
pivots = as_tensor_variable(pivots)
602-
n = pivots.shape[0]
603-
p_inv, _ = pytensor.scan(
604-
step,
605-
sequences=[pt.arange(n.copy())],
606-
outputs_info=[pt.arange(n.copy())],
607-
non_sequences=[pivots],
608-
)
609599

610-
return p_inv[-1]
600+
_pivot_to_permutation = PivotToPermutations()
611601

612602

613603
class LUFactor(Op):
@@ -810,13 +800,7 @@ def lu_solve(
810800
)
811801
x = x[pt.argsort(inv_permutation)] if trans else x
812802

813-
return LUSolve(
814-
inputs=[LU, pivots, b],
815-
outputs=[x],
816-
trans=trans,
817-
b_ndim=b_ndim,
818-
check_finite=check_finite,
819-
)(LU, pivots, b)
803+
return x
820804

821805

822806
class SolveTriangular(SolveBase):

0 commit comments

Comments
 (0)