Skip to content

Commit a532c48

Browse files
Add LUSolve Op
Add lu_solve function
1 parent 65082f2 commit a532c48

File tree

3 files changed

+355
-41
lines changed

3 files changed

+355
-41
lines changed

pytensor/tensor/slinalg.py

Lines changed: 171 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import pytensor
1212
import pytensor.tensor as pt
13+
from pytensor.compile.builders import OpFromGraph
1314
from pytensor.gradient import DisconnectedType
1415
from pytensor.graph.basic import Apply
1516
from pytensor.graph.op import Op
@@ -226,6 +227,7 @@ def __init__(
226227
):
227228
self.lower = lower
228229
self.check_finite = check_finite
230+
229231
assert b_ndim in (1, 2)
230232
self.b_ndim = b_ndim
231233
if b_ndim == 1:
@@ -303,10 +305,14 @@ def L_op(self, inputs, outputs, output_gradients):
303305

304306
solve_op = type(self)(**props_dict)
305307

306-
b_bar = solve_op(A.T, c_bar)
308+
b_bar = solve_op(A.mT, c_bar)
307309
# force outer product if vector second input
308310
A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
309311

312+
if props_dict.get("unit_diagonal", False):
313+
n = A_bar.shape[-1]
314+
A_bar = A_bar[pt.arange(n), pt.arange(n)].set(pt.zeros(n))
315+
310316
return [A_bar, b_bar]
311317

312318

@@ -577,12 +583,42 @@ def lu(
577583
)
578584

579585

586+
def _pivot_to_permutation(pivots):
587+
"""
588+
Converts a sequence of row exchanges to a permutation matrix that represents the same row exchanges. This
589+
represents the inverse permutation, which can be used to reconstruct the original matrix from its LU factorization.
590+
To get the actual permutation, the inverse permutation must be argsorted.
591+
"""
592+
593+
def step(i, permutation, swaps):
594+
j = swaps[i]
595+
x = permutation[i]
596+
y = permutation[j]
597+
598+
permutation = permutation[i].set(y)
599+
return permutation[j].set(x)
600+
601+
pivots = as_tensor_variable(pivots)
602+
n = pivots.shape[0]
603+
p_inv, _ = pytensor.scan(
604+
step,
605+
sequences=[pt.arange(n.copy())],
606+
outputs_info=[pt.arange(n.copy())],
607+
non_sequences=[pivots],
608+
)
609+
610+
return p_inv[-1]
611+
612+
580613
class LUFactor(Op):
581-
__props__ = ("overwrite_a", "check_finite")
614+
__props__ = ("overwrite_a", "check_finite", "permutation_indices")
582615

583-
def __init__(self, *, overwrite_a=False, check_finite=True):
616+
def __init__(
617+
self, *, overwrite_a=False, check_finite=True, permutation_indices=False
618+
):
584619
self.overwrite_a = overwrite_a
585620
self.check_finite = check_finite
621+
self.permutation_indices = permutation_indices
586622
self.gufunc_signature = "(m,m)->(m,m),(m)"
587623

588624
if self.overwrite_a:
@@ -596,8 +632,9 @@ def make_node(self, A):
596632
)
597633

598634
LU = matrix(shape=A.type.shape, dtype=A.type.dtype)
599-
pivots = vector(shape=(A.type.shape[0],), dtype="int32")
600-
return Apply(self, [A], [LU, pivots])
635+
pivots_or_permutations = vector(shape=(A.type.shape[0],), dtype="int32")
636+
637+
return Apply(self, [A], [LU, pivots_or_permutations])
601638

602639
def infer_shape(self, fgraph, node, shapes):
603640
n = shapes[0][0]
@@ -613,25 +650,40 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
613650

614651
def perform(self, node, inputs, outputs):
615652
A = inputs[0]
616-
LU, pivots = scipy_linalg.lu_factor(
617-
A,
618-
overwrite_a=self.overwrite_a,
619-
check_finite=self.check_finite,
620-
)
653+
654+
if self.permutation_indices:
655+
p, L, U = cast(
656+
tuple[np.ndarray, np.ndarray, np.ndarray],
657+
scipy_linalg.lu(
658+
A,
659+
overwrite_a=self.overwrite_a,
660+
check_finite=self.check_finite,
661+
p_indices=True,
662+
permute_l=False,
663+
),
664+
)
665+
LU = np.tril(L, k=-1) + U
666+
667+
else:
668+
LU, p = scipy_linalg.lu_factor(
669+
A, overwrite_a=self.overwrite_a, check_finite=self.check_finite
670+
)
621671

622672
outputs[0][0] = LU
623-
outputs[1][0] = pivots
673+
outputs[1][0] = p
624674

625675
def L_op(self, inputs, outputs, output_gradients):
626-
A = inputs[0]
676+
[A] = inputs
627677
LU_bar, _ = output_gradients
678+
LU, p_indices = outputs
628679

629-
# We need the permutation matrix P, not the pivot indices. Easiest way is to just do another LU forward.
630-
# Alternative is to do a scan over the pivot indices to convert them to permutation indices. I don't know if
631-
# that's faster or slower.
632-
P, L, U = lu(
633-
A, permute_l=False, check_finite=self.check_finite, p_indices=False
634-
)
680+
eye = ptb.identity_like(A)
681+
L = cast(TensorVariable, ptb.tril(LU, k=-1) + eye)
682+
U = cast(TensorVariable, ptb.triu(LU))
683+
684+
if not self.permutation_indices:
685+
p_indices_inv = _pivot_to_permutation(cast(TensorVariable, p_indices))
686+
p_indices = pt.argsort(p_indices_inv)
635687

636688
# Split LU_bar into L_bar and U_bar. This is valid because of the triangular structure of L and U
637689
L_bar = ptb.tril(LU_bar, k=-1)
@@ -642,13 +694,14 @@ def L_op(self, inputs, outputs, output_gradients):
642694
x2 = ptb.triu(U_bar @ U.T)
643695

644696
LT_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
645-
A_bar = P @ solve_triangular(U, LT_inv_x.T, lower=False).T
697+
B_bar = solve_triangular(U, LT_inv_x.T, lower=False).T
698+
A_bar = B_bar[p_indices]
646699

647700
return [A_bar]
648701

649702

650703
def lu_factor(
651-
a: TensorLike, *, check_finite=True
704+
a: TensorLike, *, check_finite: bool = True, permutation_indices: bool = False
652705
) -> tuple[TensorVariable, TensorVariable]:
653706
"""
654707
LU factorization with partial pivoting.
@@ -659,21 +712,112 @@ def lu_factor(
659712
Matrix to be factorized
660713
check_finite: bool
661714
Whether to check that the input matrix contains only finite numbers.
715+
permutation_indices: bool
716+
If True, returns permutation indices such that L[p] @ U = A. Otherwise returns the pivot indices, which give
717+
a record of row swaps that occured at each iteration of the LU factorization. Default is False, which matches
718+
the behavior of scipy.linalg.lu_factor.
662719
663720
Returns
664721
-------
665722
LU: TensorVariable
666723
LU decomposition of `a`
667-
pivots: TensorVariable
668-
Permutation indices
724+
pivots_or_permutations: TensorVariable
725+
An array of integers representing either the pivot indices or permutation indices, depending on the value of
726+
`permutation_indices`.
669727
"""
670728

671729
return cast(
672730
tuple[TensorVariable, TensorVariable],
673-
Blockwise(LUFactor(check_finite=check_finite))(a),
731+
Blockwise(
732+
LUFactor(check_finite=check_finite, permutation_indices=permutation_indices)
733+
)(a),
674734
)
675735

676736

737+
class LUSolve(OpFromGraph):
738+
"""Solve a system of linear equations given the LU decomposition of the matrix."""
739+
740+
__props__ = ("trans", "b_ndim", "check_finite", "overwrite_b")
741+
742+
def __init__(
743+
self,
744+
*args,
745+
trans: bool = False,
746+
b_ndim: int | None = None,
747+
check_finite: bool = False,
748+
overwrite_b: bool = False,
749+
**kwargs,
750+
):
751+
self.trans = trans
752+
self.b_ndim = b_ndim
753+
self.check_finite = check_finite
754+
self.overwrite_b = overwrite_b
755+
756+
super().__init__(*args, **kwargs)
757+
758+
759+
def lu_solve(
760+
LU_and_pivots: tuple[TensorLike, TensorLike],
761+
b: TensorLike,
762+
trans: bool = False,
763+
b_ndim: int | None = None,
764+
check_finite: bool = True,
765+
):
766+
"""
767+
Solve a system of linear equations given the LU decomposition of the matrix.
768+
769+
Parameters
770+
----------
771+
LU_and_pivots: tuple[TensorLike, TensorLike]
772+
LU decomposition of the matrix, as returned by `lu_factor`
773+
b: TensorLike
774+
Right-hand side of the equation
775+
trans: bool
776+
If True, solve A^T x = b, instead of Ax = b. Default is False
777+
b_ndim: int, optional
778+
The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
779+
of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
780+
check_finite: bool
781+
If True, check that the input matrices contain only finite numbers. Default is True.
782+
"""
783+
b_ndim = _default_b_ndim(b, b_ndim)
784+
LU, pivots = LU_and_pivots
785+
786+
LU, pivots, b = map(pt.as_tensor_variable, [LU, pivots, b])
787+
inv_permutation = _pivot_to_permutation(pivots)
788+
789+
x = b[inv_permutation] if not trans else b
790+
791+
x = solve_triangular(
792+
LU,
793+
x,
794+
lower=not trans,
795+
unit_diagonal=not trans,
796+
trans=trans,
797+
b_ndim=b_ndim,
798+
check_finite=check_finite,
799+
)
800+
801+
x = solve_triangular(
802+
LU,
803+
x,
804+
lower=trans,
805+
unit_diagonal=trans,
806+
trans=trans,
807+
b_ndim=b_ndim,
808+
check_finite=check_finite,
809+
)
810+
x = x[pt.argsort(inv_permutation)] if trans else x
811+
812+
return LUSolve(
813+
inputs=[LU, pivots, b],
814+
outputs=[x],
815+
trans=trans,
816+
b_ndim=b_ndim,
817+
check_finite=check_finite,
818+
)(LU, pivots, b)
819+
820+
677821
class SolveTriangular(SolveBase):
678822
"""Solve a system of linear equations."""
679823

@@ -688,6 +832,9 @@ class SolveTriangular(SolveBase):
688832
def __init__(self, *, unit_diagonal=False, **kwargs):
689833
if kwargs.get("overwrite_a", False):
690834
raise ValueError("overwrite_a is not supported for SolverTriangulare")
835+
836+
# There's a naming inconsistency between solve_triangular (trans) and solve (transposed). Internally, we can use
837+
# transpose everywhere, but expose the same API as scipy.linalg.solve_triangular
691838
super().__init__(**kwargs)
692839
self.unit_diagonal = unit_diagonal
693840

@@ -1546,4 +1693,5 @@ def block_diag(*matrices: TensorVariable):
15461693
"cho_solve",
15471694
"lu",
15481695
"lu_factor",
1696+
"lu_solve",
15491697
]

0 commit comments

Comments
 (0)