Skip to content

Commit 3c32e9e

Browse files
Add LUSolve Op
Add lu_solve function
1 parent 8a7f75d commit 3c32e9e

File tree

3 files changed

+289
-46
lines changed

3 files changed

+289
-46
lines changed

pytensor/tensor/slinalg.py

Lines changed: 171 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import pytensor
1212
import pytensor.tensor as pt
13+
from pytensor.compile.builders import OpFromGraph
1314
from pytensor.gradient import DisconnectedType
1415
from pytensor.graph.basic import Apply
1516
from pytensor.graph.op import Op
@@ -226,6 +227,7 @@ def __init__(
226227
):
227228
self.lower = lower
228229
self.check_finite = check_finite
230+
229231
assert b_ndim in (1, 2)
230232
self.b_ndim = b_ndim
231233
if b_ndim == 1:
@@ -303,10 +305,14 @@ def L_op(self, inputs, outputs, output_gradients):
303305

304306
solve_op = type(self)(**props_dict)
305307

306-
b_bar = solve_op(A.T, c_bar)
308+
b_bar = solve_op(A.mT, c_bar)
307309
# force outer product if vector second input
308310
A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
309311

312+
if props_dict.get("unit_diagonal", False):
313+
n = A_bar.shape[-1]
314+
A_bar = A_bar[pt.arange(n), pt.arange(n)].set(pt.zeros(n))
315+
310316
return [A_bar, b_bar]
311317

312318

@@ -577,12 +583,42 @@ def lu(
577583
)
578584

579585

586+
def _pivot_to_permutation(pivots):
587+
"""
588+
Converts a sequence of row exchanges to a vector of permutation indices that represents the same row exchanges. This
589+
represents the inverse permutation, which can be used to reconstruct the original matrix from its LU factorization.
590+
To get the actual permutation, the inverse permutation must be argsorted.
591+
"""
592+
593+
def step(i, permutation, swaps):
594+
j = swaps[i]
595+
x = permutation[i]
596+
y = permutation[j]
597+
598+
permutation = permutation[i].set(y)
599+
return permutation[j].set(x)
600+
601+
pivots = as_tensor_variable(pivots)
602+
n = pivots.shape[0]
603+
p_inv, _ = pytensor.scan(
604+
step,
605+
sequences=[pt.arange(n.copy())],
606+
outputs_info=[pt.arange(n.copy())],
607+
non_sequences=[pivots],
608+
)
609+
610+
return p_inv[-1]
611+
612+
580613
class LUFactor(Op):
581-
__props__ = ("overwrite_a", "check_finite")
614+
__props__ = ("overwrite_a", "check_finite", "permutation_indices")
582615

583-
def __init__(self, *, overwrite_a=False, check_finite=True):
616+
def __init__(
617+
self, *, overwrite_a=False, check_finite=True, permutation_indices=False
618+
):
584619
self.overwrite_a = overwrite_a
585620
self.check_finite = check_finite
621+
self.permutation_indices = permutation_indices
586622
self.gufunc_signature = "(m,m)->(m,m),(m)"
587623

588624
if self.overwrite_a:
@@ -596,8 +632,9 @@ def make_node(self, A):
596632
)
597633

598634
LU = matrix(shape=A.type.shape, dtype=A.type.dtype)
599-
pivots = vector(shape=(A.type.shape[0],), dtype="int32")
600-
return Apply(self, [A], [LU, pivots])
635+
pivots_or_permutations = vector(shape=(A.type.shape[0],), dtype="int32")
636+
637+
return Apply(self, [A], [LU, pivots_or_permutations])
601638

602639
def infer_shape(self, fgraph, node, shapes):
603640
n = shapes[0][0]
@@ -613,25 +650,40 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
613650

614651
def perform(self, node, inputs, outputs):
615652
A = inputs[0]
616-
LU, pivots = scipy_linalg.lu_factor(
617-
A,
618-
overwrite_a=self.overwrite_a,
619-
check_finite=self.check_finite,
620-
)
653+
654+
if self.permutation_indices:
655+
p, L, U = cast(
656+
tuple[np.ndarray, np.ndarray, np.ndarray],
657+
scipy_linalg.lu(
658+
A,
659+
overwrite_a=self.overwrite_a,
660+
check_finite=self.check_finite,
661+
p_indices=True,
662+
permute_l=False,
663+
),
664+
)
665+
LU = np.tril(L, k=-1) + U
666+
667+
else:
668+
LU, p = scipy_linalg.lu_factor(
669+
A, overwrite_a=self.overwrite_a, check_finite=self.check_finite
670+
)
621671

622672
outputs[0][0] = LU
623-
outputs[1][0] = pivots
673+
outputs[1][0] = p
624674

625675
def L_op(self, inputs, outputs, output_gradients):
626-
A = inputs[0]
676+
[A] = inputs
627677
LU_bar, _ = output_gradients
678+
LU, p_indices = outputs
628679

629-
# We need the permutation matrix P, not the pivot indices. Easiest way is to just do another LU forward.
630-
# Alternative is to do a scan over the pivot indices to convert them to permutation indices. I don't know if
631-
# that's faster or slower.
632-
P, L, U = lu(
633-
A, permute_l=False, check_finite=self.check_finite, p_indices=False
634-
)
680+
eye = ptb.identity_like(A)
681+
L = cast(TensorVariable, ptb.tril(LU, k=-1) + eye)
682+
U = cast(TensorVariable, ptb.triu(LU))
683+
684+
if not self.permutation_indices:
685+
p_indices_inv = _pivot_to_permutation(cast(TensorVariable, p_indices))
686+
p_indices = pt.argsort(p_indices_inv)
635687

636688
# Split LU_bar into L_bar and U_bar. This is valid because of the triangular structure of L and U
637689
L_bar = ptb.tril(LU_bar, k=-1)
@@ -642,13 +694,14 @@ def L_op(self, inputs, outputs, output_gradients):
642694
x2 = ptb.triu(U_bar @ U.T)
643695

644696
LT_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
645-
A_bar = P @ solve_triangular(U, LT_inv_x.T, lower=False).T
697+
B_bar = solve_triangular(U, LT_inv_x.T, lower=False).T
698+
A_bar = B_bar[p_indices]
646699

647700
return [A_bar]
648701

649702

650703
def lu_factor(
651-
a: TensorLike, *, check_finite=True
704+
a: TensorLike, *, check_finite: bool = True, permutation_indices: bool = False
652705
) -> tuple[TensorVariable, TensorVariable]:
653706
"""
654707
LU factorization with partial pivoting.
@@ -659,21 +712,112 @@ def lu_factor(
659712
Matrix to be factorized
660713
check_finite: bool
661714
Whether to check that the input matrix contains only finite numbers.
715+
permutation_indices: bool
716+
If True, returns permutation indices such that L[p] @ U = A. Otherwise returns the pivot indices, which give
717+
a record of row swaps that occurred at each iteration of the LU factorization. Default is False, which matches
718+
the behavior of scipy.linalg.lu_factor.
662719
663720
Returns
664721
-------
665722
LU: TensorVariable
666723
LU decomposition of `a`
667-
pivots: TensorVariable
668-
Permutation indices
724+
pivots_or_permutations: TensorVariable
725+
An array of integers representing either the pivot indices or permutation indices, depending on the value of
726+
`permutation_indices`.
669727
"""
670728

671729
return cast(
672730
tuple[TensorVariable, TensorVariable],
673-
Blockwise(LUFactor(check_finite=check_finite))(a),
731+
Blockwise(
732+
LUFactor(check_finite=check_finite, permutation_indices=permutation_indices)
733+
)(a),
674734
)
675735

676736

737+
class LUSolve(OpFromGraph):
738+
"""Solve a system of linear equations given the LU decomposition of the matrix."""
739+
740+
__props__ = ("trans", "b_ndim", "check_finite", "overwrite_b")
741+
742+
def __init__(
743+
self,
744+
*args,
745+
trans: bool = False,
746+
b_ndim: int | None = None,
747+
check_finite: bool = False,
748+
overwrite_b: bool = False,
749+
**kwargs,
750+
):
751+
self.trans = trans
752+
self.b_ndim = b_ndim
753+
self.check_finite = check_finite
754+
self.overwrite_b = overwrite_b
755+
756+
super().__init__(*args, **kwargs)
757+
758+
759+
def lu_solve(
760+
LU_and_pivots: tuple[TensorLike, TensorLike],
761+
b: TensorLike,
762+
trans: bool = False,
763+
b_ndim: int | None = None,
764+
check_finite: bool = True,
765+
):
766+
"""
767+
Solve a system of linear equations given the LU decomposition of the matrix.
768+
769+
Parameters
770+
----------
771+
LU_and_pivots: tuple[TensorLike, TensorLike]
772+
LU decomposition of the matrix, as returned by `lu_factor`
773+
b: TensorLike
774+
Right-hand side of the equation
775+
trans: bool
776+
If True, solve A^T x = b, instead of Ax = b. Default is False
777+
b_ndim: int, optional
778+
The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
779+
of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
780+
check_finite: bool
781+
If True, check that the input matrices contain only finite numbers. Default is True.
782+
"""
783+
b_ndim = _default_b_ndim(b, b_ndim)
784+
LU, pivots = LU_and_pivots
785+
786+
LU, pivots, b = map(pt.as_tensor_variable, [LU, pivots, b])
787+
inv_permutation = _pivot_to_permutation(pivots)
788+
789+
x = b[inv_permutation] if not trans else b
790+
791+
x = solve_triangular(
792+
LU,
793+
x,
794+
lower=not trans,
795+
unit_diagonal=not trans,
796+
trans=trans,
797+
b_ndim=b_ndim,
798+
check_finite=check_finite,
799+
)
800+
801+
x = solve_triangular(
802+
LU,
803+
x,
804+
lower=trans,
805+
unit_diagonal=trans,
806+
trans=trans,
807+
b_ndim=b_ndim,
808+
check_finite=check_finite,
809+
)
810+
x = x[pt.argsort(inv_permutation)] if trans else x
811+
812+
return LUSolve(
813+
inputs=[LU, pivots, b],
814+
outputs=[x],
815+
trans=trans,
816+
b_ndim=b_ndim,
817+
check_finite=check_finite,
818+
)(LU, pivots, b)
819+
820+
677821
class SolveTriangular(SolveBase):
678822
"""Solve a system of linear equations."""
679823

@@ -688,6 +832,9 @@ class SolveTriangular(SolveBase):
688832
def __init__(self, *, unit_diagonal=False, **kwargs):
689833
if kwargs.get("overwrite_a", False):
690834
raise ValueError("overwrite_a is not supported for SolverTriangulare")
835+
836+
# There's a naming inconsistency between solve_triangular (trans) and solve (transposed). Internally, we can use
837+
# transpose everywhere, but expose the same API as scipy.linalg.solve_triangular
691838
super().__init__(**kwargs)
692839
self.unit_diagonal = unit_diagonal
693840

@@ -1546,4 +1693,5 @@ def block_diag(*matrices: TensorVariable):
15461693
"cho_solve",
15471694
"lu",
15481695
"lu_factor",
1696+
"lu_solve",
15491697
]

tests/link/numba/test_slinalg.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import pytest
77
from numpy.testing import assert_allclose
8+
from scipy import linalg as scipy_linalg
89

910
import pytensor
1011
import pytensor.tensor as pt
@@ -266,15 +267,13 @@ def test_block_diag():
266267

267268

268269
def test_lamch():
269-
from scipy.linalg import get_lapack_funcs
270-
271270
from pytensor.link.numba.dispatch.slinalg import _xlamch
272271

273272
@numba.njit()
274273
def xlamch(kind):
275274
return _xlamch(kind)
276275

277-
lamch = get_lapack_funcs("lamch", (np.array([0.0], dtype=floatX),))
276+
lamch = scipy_linalg.get_lapack_funcs("lamch", (np.array([0.0], dtype=floatX),))
278277

279278
np.testing.assert_allclose(xlamch("E"), lamch("E"))
280279
np.testing.assert_allclose(xlamch("S"), lamch("S"))
@@ -289,23 +288,19 @@ def xlamch(kind):
289288
)
290289
def test_xlange(ord_numba, ord_scipy):
291290
# xlange is called internally only, we don't dispatch pt.linalg.norm to it
292-
from scipy import linalg
293-
294291
from pytensor.link.numba.dispatch.slinalg import _xlange
295292

296293
@numba.njit()
297294
def xlange(x, ord):
298295
return _xlange(x, ord)
299296

300297
x = np.random.normal(size=(5, 5)).astype(floatX)
301-
np.testing.assert_allclose(xlange(x, ord_numba), linalg.norm(x, ord_scipy))
298+
np.testing.assert_allclose(xlange(x, ord_numba), scipy_linalg.norm(x, ord_scipy))
302299

303300

304301
@pytest.mark.parametrize("ord_numba, ord_scipy", [("1", 1), ("I", np.inf)])
305302
def test_xgecon(ord_numba, ord_scipy):
306303
# gecon is called internally only, we don't dispatch pt.linalg.norm to it
307-
from scipy.linalg import get_lapack_funcs
308-
309304
from pytensor.link.numba.dispatch.slinalg import _xgecon, _xlange
310305

311306
@numba.njit()
@@ -320,7 +315,7 @@ def gecon(x, norm):
320315

321316
# Test against direct call to the underlying LAPACK functions
322317
# Solution does **not** agree with 1 / np.linalg.cond(x) !
323-
lange, gecon = get_lapack_funcs(("lange", "gecon"), (x,))
318+
lange, gecon = scipy_linalg.get_lapack_funcs(("lange", "gecon"), (x,))
324319
norm = lange(ord_numba, x)
325320
rcond2, _ = gecon(x, norm, norm=ord_numba)
326321

@@ -330,8 +325,6 @@ def gecon(x, norm):
330325

331326
@pytest.mark.parametrize("overwrite_a", [True, False])
332327
def test_getrf(overwrite_a):
333-
from scipy.linalg import lu_factor
334-
335328
from pytensor.link.numba.dispatch.slinalg import _getrf
336329

337330
# TODO: Refactor this test to use compare_numba_and_py after we implement lu_factor in pytensor
@@ -345,7 +338,7 @@ def getrf(x, overwrite_a):
345338
x
346339
) # x needs to be fortran-contiguous going into getrf for the overwrite option to work
347340

348-
lu, ipiv = lu_factor(x, overwrite_a=False)
341+
lu, ipiv = scipy_linalg.lu_factor(x, overwrite_a=False)
349342
LU, IPIV, info = getrf(x, overwrite_a=overwrite_a)
350343

351344
assert info == 0
@@ -364,9 +357,6 @@ def getrf(x, overwrite_a):
364357
@pytest.mark.parametrize("overwrite_b", [True, False])
365358
@pytest.mark.parametrize("b_shape", [(5,), (5, 3)], ids=["b_1d", "b_2d"])
366359
def test_getrs(trans, overwrite_a, overwrite_b, b_shape):
367-
from scipy.linalg import lu_factor
368-
from scipy.linalg import lu_solve as sp_lu_solve
369-
370360
from pytensor.link.numba.dispatch.slinalg import _getrf, _getrs
371361

372362
# TODO: Refactor this test to use compare_numba_and_py after we implement lu_solve in pytensor
@@ -384,8 +374,8 @@ def lu_solve(a, b, trans, overwrite_a, overwrite_b):
384374
a = np.asfortranarray(a)
385375
b = np.asfortranarray(b)
386376

387-
lu_and_piv = lu_factor(a, overwrite_a=False)
388-
x_sp = sp_lu_solve(lu_and_piv, b, trans, overwrite_b=False)
377+
lu_and_piv = scipy_linalg.lu_factor(a, overwrite_a=False)
378+
x_sp = scipy_linalg.lu_solve(lu_and_piv, b, trans, overwrite_b=False)
389379

390380
x, lu, info = lu_solve(
391381
a, b, trans, overwrite_a=overwrite_a, overwrite_b=overwrite_b

0 commit comments

Comments
 (0)