Skip to content

Commit 91815b3

Browse files
Add lu_solve function
1 parent 8a7f75d commit 91815b3

File tree

3 files changed

+241
-46
lines changed

3 files changed

+241
-46
lines changed

pytensor/tensor/slinalg.py

Lines changed: 121 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def __init__(
226226
):
227227
self.lower = lower
228228
self.check_finite = check_finite
229+
229230
assert b_ndim in (1, 2)
230231
self.b_ndim = b_ndim
231232
if b_ndim == 1:
@@ -303,10 +304,14 @@ def L_op(self, inputs, outputs, output_gradients):
303304

304305
solve_op = type(self)(**props_dict)
305306

306-
b_bar = solve_op(A.T, c_bar)
307+
b_bar = solve_op(A.mT, c_bar)
307308
# force outer product if vector second input
308309
A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
309310

311+
if props_dict.get("unit_diagonal", False):
312+
n = A_bar.shape[-1]
313+
A_bar = A_bar[pt.arange(n), pt.arange(n)].set(pt.zeros(n))
314+
310315
return [A_bar, b_bar]
311316

312317

@@ -577,12 +582,42 @@ def lu(
577582
)
578583

579584

585+
def _pivot_to_permutation(pivots):
586+
"""
587+
Converts a sequence of row exchanges to a permutation matrix that represents the same row exchanges. This
588+
represents the inverse permutation, which can be used to reconstruct the original matrix from its LU factorization.
589+
To get the actual permutation, the inverse permutation must be argsorted.
590+
"""
591+
592+
def step(i, permutation, swaps):
593+
j = swaps[i]
594+
x = permutation[i]
595+
y = permutation[j]
596+
597+
permutation = permutation[i].set(y)
598+
return permutation[j].set(x)
599+
600+
pivots = as_tensor_variable(pivots)
601+
n = pivots.shape[0]
602+
p_inv, _ = pytensor.scan(
603+
step,
604+
sequences=[pt.arange(n.copy())],
605+
outputs_info=[pt.arange(n.copy())],
606+
non_sequences=[pivots],
607+
)
608+
609+
return p_inv[-1]
610+
611+
580612
class LUFactor(Op):
581-
__props__ = ("overwrite_a", "check_finite")
613+
__props__ = ("overwrite_a", "check_finite", "permutation_indices")
582614

583-
def __init__(self, *, overwrite_a=False, check_finite=True):
615+
def __init__(
616+
self, *, overwrite_a=False, check_finite=True, permutation_indices=False
617+
):
584618
self.overwrite_a = overwrite_a
585619
self.check_finite = check_finite
620+
self.permutation_indices = permutation_indices
586621
self.gufunc_signature = "(m,m)->(m,m),(m)"
587622

588623
if self.overwrite_a:
@@ -596,8 +631,9 @@ def make_node(self, A):
596631
)
597632

598633
LU = matrix(shape=A.type.shape, dtype=A.type.dtype)
599-
pivots = vector(shape=(A.type.shape[0],), dtype="int32")
600-
return Apply(self, [A], [LU, pivots])
634+
pivots_or_permutations = vector(shape=(A.type.shape[0],), dtype="int32")
635+
636+
return Apply(self, [A], [LU, pivots_or_permutations])
601637

602638
def infer_shape(self, fgraph, node, shapes):
603639
n = shapes[0][0]
@@ -613,25 +649,40 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
613649

614650
def perform(self, node, inputs, outputs):
615651
A = inputs[0]
616-
LU, pivots = scipy_linalg.lu_factor(
617-
A,
618-
overwrite_a=self.overwrite_a,
619-
check_finite=self.check_finite,
620-
)
652+
653+
if self.permutation_indices:
654+
p, L, U = cast(
655+
tuple[np.ndarray, np.ndarray, np.ndarray],
656+
scipy_linalg.lu(
657+
A,
658+
overwrite_a=self.overwrite_a,
659+
check_finite=self.check_finite,
660+
p_indices=True,
661+
permute_l=False,
662+
),
663+
)
664+
LU = np.tril(L, k=-1) + U
665+
666+
else:
667+
LU, p = scipy_linalg.lu_factor(
668+
A, overwrite_a=self.overwrite_a, check_finite=self.check_finite
669+
)
621670

622671
outputs[0][0] = LU
623-
outputs[1][0] = pivots
672+
outputs[1][0] = p
624673

625674
def L_op(self, inputs, outputs, output_gradients):
626-
A = inputs[0]
675+
[A] = inputs
627676
LU_bar, _ = output_gradients
677+
LU, p_indices = outputs
628678

629-
# We need the permutation matrix P, not the pivot indices. Easiest way is to just do another LU forward.
630-
# Alternative is to do a scan over the pivot indices to convert them to permutation indices. I don't know if
631-
# that's faster or slower.
632-
P, L, U = lu(
633-
A, permute_l=False, check_finite=self.check_finite, p_indices=False
634-
)
679+
eye = ptb.identity_like(A)
680+
L = cast(TensorVariable, ptb.tril(LU, k=-1) + eye)
681+
U = cast(TensorVariable, ptb.triu(LU))
682+
683+
if not self.permutation_indices:
684+
p_indices_inv = _pivot_to_permutation(cast(TensorVariable, p_indices))
685+
p_indices = pt.argsort(p_indices_inv)
635686

636687
# Split LU_bar into L_bar and U_bar. This is valid because of the triangular structure of L and U
637688
L_bar = ptb.tril(LU_bar, k=-1)
@@ -642,13 +693,14 @@ def L_op(self, inputs, outputs, output_gradients):
642693
x2 = ptb.triu(U_bar @ U.T)
643694

644695
LT_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
645-
A_bar = P @ solve_triangular(U, LT_inv_x.T, lower=False).T
696+
B_bar = solve_triangular(U, LT_inv_x.T, lower=False).T
697+
A_bar = B_bar[p_indices]
646698

647699
return [A_bar]
648700

649701

650702
def lu_factor(
651-
a: TensorLike, *, check_finite=True
703+
a: TensorLike, *, check_finite: bool = True, permutation_indices: bool = False
652704
) -> tuple[TensorVariable, TensorVariable]:
653705
"""
654706
LU factorization with partial pivoting.
@@ -659,21 +711,63 @@ def lu_factor(
659711
Matrix to be factorized
660712
check_finite: bool
661713
Whether to check that the input matrix contains only finite numbers.
714+
permutation_indices: bool
715+
If True, returns permutation indices such that L[p] @ U = A. Otherwise returns the pivot indices, which give
716+
a record of row swaps that occured at each iteration of the LU factorization. Default is False, which matches
717+
the behavior of scipy.linalg.lu_factor.
662718
663719
Returns
664720
-------
665721
LU: TensorVariable
666722
LU decomposition of `a`
667-
pivots: TensorVariable
668-
Permutation indices
723+
pivots_or_permutations: TensorVariable
724+
An array of integers representing either the pivot indices or permutation indices, depending on the value of
725+
`permutation_indices`.
669726
"""
670727

671728
return cast(
672729
tuple[TensorVariable, TensorVariable],
673-
Blockwise(LUFactor(check_finite=check_finite))(a),
730+
Blockwise(
731+
LUFactor(check_finite=check_finite, permutation_indices=permutation_indices)
732+
)(a),
674733
)
675734

676735

736+
def lu_solve(
737+
LU_and_pivots: tuple[TensorVariable, TensorVariable],
738+
b: TensorVariable,
739+
trans=False,
740+
b_ndim=None,
741+
check_finite=True,
742+
):
743+
LU, pivots = LU_and_pivots
744+
inv_permutation = _pivot_to_permutation(pivots)
745+
746+
x = b[inv_permutation] if not trans else b
747+
748+
x = solve_triangular(
749+
LU,
750+
x,
751+
lower=not trans,
752+
unit_diagonal=not trans,
753+
trans=trans,
754+
b_ndim=b_ndim,
755+
check_finite=check_finite,
756+
)
757+
758+
x = solve_triangular(
759+
LU,
760+
x,
761+
lower=trans,
762+
unit_diagonal=trans,
763+
trans=trans,
764+
b_ndim=b_ndim,
765+
check_finite=check_finite,
766+
)
767+
768+
return x[pt.argsort(inv_permutation)] if trans else x
769+
770+
677771
class SolveTriangular(SolveBase):
678772
"""Solve a system of linear equations."""
679773

@@ -688,6 +782,9 @@ class SolveTriangular(SolveBase):
688782
def __init__(self, *, unit_diagonal=False, **kwargs):
689783
if kwargs.get("overwrite_a", False):
690784
raise ValueError("overwrite_a is not supported for SolverTriangulare")
785+
786+
# There's a naming inconsistency between solve_triangular (trans) and solve (transposed). Internally, we can use
787+
# transpose everywhere, but expose the same API as scipy.linalg.solve_triangular
691788
super().__init__(**kwargs)
692789
self.unit_diagonal = unit_diagonal
693790

@@ -1546,4 +1643,5 @@ def block_diag(*matrices: TensorVariable):
15461643
"cho_solve",
15471644
"lu",
15481645
"lu_factor",
1646+
"lu_solve",
15491647
]

tests/link/numba/test_slinalg.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import pytest
77
from numpy.testing import assert_allclose
8+
from scipy import linalg as scipy_linalg
89

910
import pytensor
1011
import pytensor.tensor as pt
@@ -266,15 +267,13 @@ def test_block_diag():
266267

267268

268269
def test_lamch():
269-
from scipy.linalg import get_lapack_funcs
270-
271270
from pytensor.link.numba.dispatch.slinalg import _xlamch
272271

273272
@numba.njit()
274273
def xlamch(kind):
275274
return _xlamch(kind)
276275

277-
lamch = get_lapack_funcs("lamch", (np.array([0.0], dtype=floatX),))
276+
lamch = scipy_linalg.get_lapack_funcs("lamch", (np.array([0.0], dtype=floatX),))
278277

279278
np.testing.assert_allclose(xlamch("E"), lamch("E"))
280279
np.testing.assert_allclose(xlamch("S"), lamch("S"))
@@ -289,23 +288,19 @@ def xlamch(kind):
289288
)
290289
def test_xlange(ord_numba, ord_scipy):
291290
# xlange is called internally only, we don't dispatch pt.linalg.norm to it
292-
from scipy import linalg
293-
294291
from pytensor.link.numba.dispatch.slinalg import _xlange
295292

296293
@numba.njit()
297294
def xlange(x, ord):
298295
return _xlange(x, ord)
299296

300297
x = np.random.normal(size=(5, 5)).astype(floatX)
301-
np.testing.assert_allclose(xlange(x, ord_numba), linalg.norm(x, ord_scipy))
298+
np.testing.assert_allclose(xlange(x, ord_numba), scipy_linalg.norm(x, ord_scipy))
302299

303300

304301
@pytest.mark.parametrize("ord_numba, ord_scipy", [("1", 1), ("I", np.inf)])
305302
def test_xgecon(ord_numba, ord_scipy):
306303
# gecon is called internally only, we don't dispatch pt.linalg.norm to it
307-
from scipy.linalg import get_lapack_funcs
308-
309304
from pytensor.link.numba.dispatch.slinalg import _xgecon, _xlange
310305

311306
@numba.njit()
@@ -320,7 +315,7 @@ def gecon(x, norm):
320315

321316
# Test against direct call to the underlying LAPACK functions
322317
# Solution does **not** agree with 1 / np.linalg.cond(x) !
323-
lange, gecon = get_lapack_funcs(("lange", "gecon"), (x,))
318+
lange, gecon = scipy_linalg.get_lapack_funcs(("lange", "gecon"), (x,))
324319
norm = lange(ord_numba, x)
325320
rcond2, _ = gecon(x, norm, norm=ord_numba)
326321

@@ -330,8 +325,6 @@ def gecon(x, norm):
330325

331326
@pytest.mark.parametrize("overwrite_a", [True, False])
332327
def test_getrf(overwrite_a):
333-
from scipy.linalg import lu_factor
334-
335328
from pytensor.link.numba.dispatch.slinalg import _getrf
336329

337330
# TODO: Refactor this test to use compare_numba_and_py after we implement lu_factor in pytensor
@@ -345,7 +338,7 @@ def getrf(x, overwrite_a):
345338
x
346339
) # x needs to be fortran-contiguous going into getrf for the overwrite option to work
347340

348-
lu, ipiv = lu_factor(x, overwrite_a=False)
341+
lu, ipiv = scipy_linalg.lu_factor(x, overwrite_a=False)
349342
LU, IPIV, info = getrf(x, overwrite_a=overwrite_a)
350343

351344
assert info == 0
@@ -364,9 +357,6 @@ def getrf(x, overwrite_a):
364357
@pytest.mark.parametrize("overwrite_b", [True, False])
365358
@pytest.mark.parametrize("b_shape", [(5,), (5, 3)], ids=["b_1d", "b_2d"])
366359
def test_getrs(trans, overwrite_a, overwrite_b, b_shape):
367-
from scipy.linalg import lu_factor
368-
from scipy.linalg import lu_solve as sp_lu_solve
369-
370360
from pytensor.link.numba.dispatch.slinalg import _getrf, _getrs
371361

372362
# TODO: Refactor this test to use compare_numba_and_py after we implement lu_solve in pytensor
@@ -384,8 +374,8 @@ def lu_solve(a, b, trans, overwrite_a, overwrite_b):
384374
a = np.asfortranarray(a)
385375
b = np.asfortranarray(b)
386376

387-
lu_and_piv = lu_factor(a, overwrite_a=False)
388-
x_sp = sp_lu_solve(lu_and_piv, b, trans, overwrite_b=False)
377+
lu_and_piv = scipy_linalg.lu_factor(a, overwrite_a=False)
378+
x_sp = scipy_linalg.lu_solve(lu_and_piv, b, trans, overwrite_b=False)
389379

390380
x, lu, info = lu_solve(
391381
a, b, trans, overwrite_a=overwrite_a, overwrite_b=overwrite_b

0 commit comments

Comments
 (0)