Commit 86c5539

Add LU decomposition Op
1 parent: 2460f2d

File tree

2 files changed (+246, −3 lines)


pytensor/tensor/slinalg.py

Lines changed: 184 additions & 3 deletions
@@ -1,14 +1,17 @@
 import logging
 import typing
 import warnings
+from collections.abc import Sequence
 from functools import reduce
 from typing import Literal, cast

 import numpy as np
-import scipy.linalg
+import scipy

 import pytensor
 import pytensor.tensor as pt
+from pytensor import Variable
+from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor import TensorLike, as_tensor_variable
@@ -25,8 +28,6 @@


 class Cholesky(Op):
-    # TODO: LAPACK wrapper with in-place behavior, for solve also
-
     __props__ = ("lower", "check_finite", "on_error", "overwrite_a")
     gufunc_signature = "(m,m)->(m,m)"

@@ -396,6 +397,186 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
     )(A, b)


+class LU(Op):
+    """Decompose a matrix into lower and upper triangular matrices."""
+
+    __props__ = ("permute_l", "overwrite_a", "check_finite", "p_indices")
+
+    def __init__(
+        self, *, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
+    ):
+        self.permute_l = permute_l
+        self.check_finite = check_finite
+        self.p_indices = p_indices
+        self.overwrite_a = overwrite_a
+
+        if self.permute_l:
+            # permute_l overrides p_indices in the scipy function, so we copy that behavior
+            self.gufunc_signature = "(m,m)->(m,m),(m,m)"
+        elif self.p_indices:
+            self.gufunc_signature = "(m,m)->(m),(m,m),(m,m)"
+        else:
+            self.gufunc_signature = "(m,m)->(m,m),(m,m),(m,m)"
+
+        if self.overwrite_a:
+            self.destroy_map = {0: [0]}
+
+    def infer_shape(self, fgraph, node, shapes):
+        n = shapes[0][0]
+        if self.permute_l:
+            return [(n, n), (n, n)]
+        elif self.p_indices:
+            return [(n,), (n, n), (n, n)]
+        else:
+            return [(n, n), (n, n), (n, n)]
+
+    def make_node(self, x):
+        x = as_tensor_variable(x)
+        if x.type.ndim != 2:
+            raise TypeError(
+                f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
+            )
+
+        real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d"
+        p_dtype = "int32" if self.p_indices else np.dtype(real_dtype)
+
+        L = tensor(shape=x.type.shape, dtype=real_dtype)
+        U = tensor(shape=x.type.shape, dtype=real_dtype)
+
+        if self.permute_l:
+            # In this case, L is actually P @ L
+            return Apply(self, inputs=[x], outputs=[L, U])
+        elif self.p_indices:
+            p = tensor(shape=(x.type.shape[0],), dtype=p_dtype)
+            return Apply(self, inputs=[x], outputs=[p, L, U])
+        else:
+            P = tensor(shape=x.type.shape, dtype=p_dtype)
+            return Apply(self, inputs=[x], outputs=[P, L, U])
+
+    def perform(self, node, inputs, outputs):
+        [A] = inputs
+
+        out = scipy.linalg.lu(
+            A,
+            permute_l=self.permute_l,
+            overwrite_a=self.overwrite_a,
+            check_finite=self.check_finite,
+            p_indices=self.p_indices,
+        )
+
+        outputs[0][0] = out[0]
+        outputs[1][0] = out[1]
+
+        if not self.permute_l:
+            # In all cases except permute_l, there are three returns
+            outputs[2][0] = out[2]
+
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        if 0 in allowed_inplace_inputs:
+            new_props = self._props_dict()  # type: ignore
+            new_props["overwrite_a"] = True
+            return type(self)(**new_props)
+        else:
+            return self
+
+    def L_op(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
+    ) -> list[Variable]:
+        r"""
+        Derivation follows "Differentiation of Matrix Functionals Using Triangular
+        Factorization" by F. R. de Hoog, R. S. Anderssen, and M. A. Lukas.
+        """
+        [A] = inputs
+        A = cast(TensorVariable, A)
+
+        if self.permute_l:
+            PL_bar, U_bar = output_grads
+
+            # TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
+            P, L, U = lu(  # type: ignore
+                A, permute_l=False, check_finite=self.check_finite, p_indices=False
+            )
+
+            # Permutation matrix is orthogonal
+            L_bar = (
+                P.T @ PL_bar
+                if not isinstance(PL_bar.type, DisconnectedType)
+                else pt.zeros_like(A)
+            )
+
+        elif self.p_indices:
+            p, L, U = outputs
+
+            # TODO: rewrite to p_indices = False for graphs where we need to compute the gradient
+            P = pt.eye(A.shape[0])[p]
+            _, L_bar, U_bar = output_grads
+        else:
+            P, L, U = outputs
+            _, L_bar, U_bar = output_grads
+
+        L_bar = (
+            L_bar if not isinstance(L_bar.type, DisconnectedType) else pt.zeros_like(A)
+        )
+        U_bar = (
+            U_bar if not isinstance(U_bar.type, DisconnectedType) else pt.zeros_like(A)
+        )
+
+        x1 = ptb.tril(L.T @ L_bar, k=-1)
+        x2 = ptb.triu(U_bar @ U.T)
+
+        L_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
+        A_bar = P @ solve_triangular(U, L_inv_x.T, lower=False).T
+
+        return [A_bar]
+
+
+def lu(
+    a: TensorLike, permute_l=False, check_finite=True, p_indices=False
+) -> (
+    tuple[TensorVariable, TensorVariable, TensorVariable]
+    | tuple[TensorVariable, TensorVariable]
+):
+    """
+    Factorize a matrix as the product of a permutation matrix, a unit lower triangular
+    matrix, and an upper triangular matrix:
+
+    .. math::
+
+        A = P L U
+
+    where P is a permutation matrix, L is lower triangular with unit diagonal elements, and U is upper triangular.
+
+    Parameters
+    ----------
+    a: TensorLike
+        Matrix to be factorized
+    permute_l: bool
+        If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will
+        be returned in this case, and PL will not be lower triangular.
+    check_finite: bool
+        Whether to check that the input matrix contains only finite numbers.
+    p_indices: bool
+        If True, return integer row indices for the permutation matrix. Otherwise, return the permutation matrix
+        itself.
+
+    Returns
+    -------
+    P: TensorVariable
+        Permutation matrix, or array of integer indices for the permutation matrix. Not returned if permute_l is True.
+    L: TensorVariable
+        Lower triangular matrix, or product of permutation and unit lower triangular matrices if permute_l is True.
+    U: TensorVariable
+        Upper triangular matrix
+    """
+    return cast(
+        tuple[TensorVariable, TensorVariable, TensorVariable]
+        | tuple[TensorVariable, TensorVariable],
+        LU(permute_l=permute_l, check_finite=check_finite, p_indices=p_indices)(a),
+    )
+
+
 class SolveTriangular(SolveBase):
     """Solve a system of linear equations."""

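The vector-Jacobian product that L_op assembles corresponds to the triangular-factorization result of de Hoog, Anderssen, and Lukas cited in its docstring. Written out as a single expression (a reading of the code above, not text from the commit), with tril(X, -1) the strictly lower triangle of X and triu(X) its upper triangle including the diagonal:

    \bar{A} = P \, L^{-T} \left( \mathrm{tril}(L^{T} \bar{L},\, -1) + \mathrm{triu}(\bar{U} U^{T}) \right) U^{-T}

The two solve_triangular calls apply L^{-T} and U^{-T} without forming explicit inverses, with unit_diagonal=True exploiting the unit diagonal of L; in the permute_l case, \bar{L} is first recovered as P^{T} times the cotangent of PL, since permutation matrices are orthogonal.
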
tests/tensor/test_slinalg.py

Lines changed: 62 additions & 0 deletions
@@ -21,6 +21,7 @@
     cholesky,
     eigvalsh,
     expm,
+    lu,
     solve,
     solve_continuous_lyapunov,
     solve_discrete_are,
@@ -437,6 +438,67 @@ def test_solve_dtype(self):
         assert x.dtype == x_result.dtype, (A_dtype, b_dtype)


+@pytest.mark.parametrize("permute_l", [True, False], ids=["permute_l", "no_permute_l"])
+@pytest.mark.parametrize("p_indices", [True, False], ids=["p_indices", "no_p_indices"])
+@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
+def test_lu_decomposition(permute_l, p_indices, complex):
+    dtype = config.floatX if not complex else f"complex{int(config.floatX[-2:]) * 2}"
+    A = tensor("A", shape=(None, None), dtype=dtype)
+    out = lu(A, permute_l=permute_l, p_indices=p_indices)
+
+    f = pytensor.function([A], out)
+
+    rng = np.random.default_rng(utt.fetch_seed())
+    x = rng.normal(size=(5, 5)).astype(config.floatX)
+    if complex:
+        x = x + 1j * rng.normal(size=(5, 5)).astype(config.floatX)
+
+    out = f(x)
+
+    if permute_l:
+        PL, U = out
+        x_rebuilt = PL @ U
+    elif p_indices:
+        p, L, U = out
+        P = np.eye(5)[p]
+        x_rebuilt = P @ L @ U
+    else:
+        P, L, U = out
+        x_rebuilt = P @ L @ U
+
+    np.testing.assert_allclose(x, x_rebuilt)
+    scipy_out = scipy.linalg.lu(x, permute_l=permute_l, p_indices=p_indices)
+
+    for a, b in zip(out, scipy_out, strict=True):
+        np.testing.assert_allclose(a, b)
+
+
+@pytest.mark.parametrize("grad_case", [0, 1, 2], ids=["U_only", "L_only", "U_and_L"])
+@pytest.mark.parametrize("permute_l", [True, False])
+@pytest.mark.parametrize("p_indices", [True, False])
+def test_lu_grad(grad_case, permute_l, p_indices):
+    rng = np.random.default_rng(utt.fetch_seed())
+    A_value = rng.normal(size=(5, 5))
+
+    def f_pt(A):
+        out = lu(A, permute_l=permute_l, p_indices=p_indices)
+
+        if permute_l:
+            L, U = out
+        else:
+            _, L, U = out
+
+        match grad_case:
+            case 0:
+                return U.sum()
+            case 1:
+                return L.sum()
+            case 2:
+                return U.sum() + L.sum()
+
+    utt.verify_grad(f_pt, [A_value], rng=rng)
+
+
 def test_cho_solve():
     rng = np.random.default_rng(utt.fetch_seed())
     A = matrix()

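For context, a minimal usage sketch of the new lu helper, mirroring the round-trip check in test_lu_decomposition above (the variable names and test values here are illustrative, not part of the commit):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import lu

# Symbolic square matrix; lu() with default flags returns the triple (P, L, U)
A = pt.tensor("A", shape=(None, None), dtype="float64")
P, L, U = lu(A)
f = pytensor.function([A], [P, L, U])

x = np.random.default_rng(0).normal(size=(4, 4))
P_val, L_val, U_val = f(x)
# The factors reconstruct the input: L has a unit diagonal, U is upper triangular
np.testing.assert_allclose(P_val @ L_val @ U_val, x)

Passing permute_l=True collapses the result to the pair (P @ L, U), and p_indices=True replaces the permutation matrix with integer row indices, mirroring scipy.linalg.lu.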