Skip to content

Commit 331e8ab

Browse files
Add LU decomposition Op
1 parent 51ea1a0 commit 331e8ab

File tree

2 files changed

+245
-1
lines changed

2 files changed

+245
-1
lines changed

pytensor/tensor/slinalg.py

Lines changed: 183 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
import warnings
33
from collections.abc import Sequence
44
from functools import reduce
5-
from typing import Literal, cast
5+
from typing import Literal, cast, Sequence
66

77
import numpy as np
88
import scipy.linalg as scipy_linalg
99
from numpy.exceptions import ComplexWarning
1010

1111
import pytensor
1212
import pytensor.tensor as pt
13+
from pytensor.gradient import DisconnectedType
1314
from pytensor.graph.basic import Apply
1415
from pytensor.graph.op import Op
1516
from pytensor.tensor import TensorLike, as_tensor_variable
@@ -303,6 +304,7 @@ def L_op(self, inputs, outputs, output_gradients):
303304
}
304305
)
305306
b_bar = trans_solve_op(A.T, c_bar)
307+
306308
# force outer product if vector second input
307309
A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
308310

@@ -381,6 +383,186 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
381383
)(A, b)
382384

383385

386+
class LU(Op):
387+
"""Decompose a matrix into lower and upper triangular matrices."""
388+
389+
__props__ = ("permute_l", "overwrite_a", "check_finite", "p_indices")
390+
391+
def __init__(
392+
self, *, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
393+
):
394+
self.permute_l = permute_l
395+
self.check_finite = check_finite
396+
self.p_indices = p_indices
397+
self.overwrite_a = overwrite_a
398+
399+
if self.permute_l:
400+
# permute_l overrides p_indices in the scipy function. We can copy that behavior
401+
self.gufunc_signature = "(m,m)->(m,m),(m,m)"
402+
elif self.p_indices:
403+
self.gufunc_signature = "(m,m)->(m),(m,m),(m,m)"
404+
else:
405+
self.gufunc_signature = "(m,m)->(m,m),(m,m),(m,m)"
406+
407+
if self.overwrite_a:
408+
self.destroy_map = {0: [0]}
409+
410+
def infer_shape(self, fgraph, node, shapes):
411+
n = shapes[0][0]
412+
if self.permute_l:
413+
return [(n, n), (n, n)]
414+
elif self.p_indices:
415+
return [(n,), (n, n), (n, n)]
416+
else:
417+
return [(n, n), (n, n), (n, n)]
418+
419+
def make_node(self, x):
420+
x = as_tensor_variable(x)
421+
if x.type.ndim != 2:
422+
raise TypeError(
423+
f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
424+
)
425+
426+
real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d"
427+
p_dtype = "int32" if self.p_indices else np.dtype(real_dtype)
428+
429+
L = tensor(shape=x.type.shape, dtype=real_dtype)
430+
U = tensor(shape=x.type.shape, dtype=real_dtype)
431+
432+
if self.permute_l:
433+
# In this case, L is actually P @ L
434+
return Apply(self, inputs=[x], outputs=[L, U])
435+
elif self.p_indices:
436+
p = tensor(shape=(x.type.shape[0],), dtype=p_dtype)
437+
return Apply(self, inputs=[x], outputs=[p, L, U])
438+
else:
439+
P = tensor(shape=x.type.shape, dtype=p_dtype)
440+
return Apply(self, inputs=[x], outputs=[P, L, U])
441+
442+
def perform(self, node, inputs, outputs):
443+
[A] = inputs
444+
445+
out = scipy_linalg.lu(
446+
A,
447+
permute_l=self.permute_l,
448+
overwrite_a=self.overwrite_a,
449+
check_finite=self.check_finite,
450+
p_indices=self.p_indices,
451+
)
452+
453+
outputs[0][0] = out[0]
454+
outputs[1][0] = out[1]
455+
456+
if not self.permute_l:
457+
# In all cases except permute_l, there are three returns
458+
outputs[2][0] = out[2]
459+
460+
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
461+
if 0 in allowed_inplace_inputs:
462+
new_props = self._props_dict() # type: ignore
463+
new_props["overwrite_a"] = True
464+
return type(self)(**new_props)
465+
else:
466+
return self
467+
468+
def L_op(
469+
self,
470+
inputs: Sequence[ptb.Variable],
471+
outputs: Sequence[ptb.Variable],
472+
output_grads: Sequence[ptb.Variable],
473+
) -> list[ptb.Variable]:
474+
r"""
475+
Derivation is due to Differentiation of Matrix Functionals Using Triangular Factorization
476+
F. R. De Hoog, R.S. Anderssen, M. A. Lukas
477+
"""
478+
[A] = inputs
479+
A = cast(TensorVariable, A)
480+
481+
if self.permute_l:
482+
PL_bar, U_bar = output_grads
483+
484+
# TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
485+
P, L, U = lu( # type: ignore
486+
A, permute_l=False, check_finite=self.check_finite, p_indices=False
487+
)
488+
489+
# Permutation matrix is orthogonal
490+
L_bar = (
491+
P.T @ PL_bar
492+
if not isinstance(PL_bar.type, DisconnectedType)
493+
else pt.zeros_like(A)
494+
)
495+
496+
elif self.p_indices:
497+
p, L, U = outputs
498+
499+
# TODO: rewrite to p_indices = False for graphs where we need to compute the gradient
500+
P = pt.eye(A.shape[0])[p]
501+
_, L_bar, U_bar = output_grads
502+
else:
503+
P, L, U = outputs
504+
_, L_bar, U_bar = output_grads
505+
506+
L_bar = (
507+
L_bar if not isinstance(L_bar.type, DisconnectedType) else pt.zeros_like(A)
508+
)
509+
U_bar = (
510+
U_bar if not isinstance(U_bar.type, DisconnectedType) else pt.zeros_like(A)
511+
)
512+
513+
x1 = ptb.tril(L.T @ L_bar, k=-1)
514+
x2 = ptb.triu(U_bar @ U.T)
515+
516+
L_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
517+
A_bar = P @ solve_triangular(U, L_inv_x.T, lower=False).T
518+
519+
return [A_bar]
520+
521+
522+
def lu(
523+
a: TensorLike, permute_l=False, check_finite=True, p_indices=False
524+
) -> (
525+
tuple[TensorVariable, TensorVariable, TensorVariable]
526+
| tuple[TensorVariable, TensorVariable]
527+
):
528+
"""
529+
Factorize a matrix as the product of a unit lower triangular matrix and an upper triangular matrix:
530+
531+
... math::
532+
533+
A = P L U
534+
535+
Where P is a permutation matrix, L is lower triangular with unit diagonal elements, and U is upper triangular.
536+
537+
Parameters
538+
----------
539+
a: TensorLike
540+
Matrix to be factorized
541+
permute_l: bool
542+
If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will
543+
be returned in this case, and PL will not be lower triangular.
544+
check_finite: bool
545+
Whether to check that the input matrix contains only finite numbers.
546+
p_indices: bool
547+
If True, return integer matrix indices for the permutation matrix. Otherwise, return the permutation matrix
548+
itself.
549+
550+
Returns
551+
-------
552+
P: TensorVariable
553+
Permutation matrix, or array of integer indices for permutation matrix. Not returned if permute_l is True.
554+
L: TensorVariable
555+
Lower triangular matrix, or product of permutation and unit lower triangular matrices if permute_l is True.
556+
U: TensorVariable
557+
Upper triangular matrix
558+
"""
559+
return cast(
560+
tuple[TensorVariable, TensorVariable, TensorVariable]
561+
| tuple[TensorVariable, TensorVariable],
562+
LU(permute_l=permute_l, check_finite=check_finite, p_indices=p_indices)(a),
563+
)
564+
565+
384566
class SolveTriangular(SolveBase):
385567
"""Solve a system of linear equations."""
386568

tests/tensor/test_slinalg.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
cholesky,
2222
eigvalsh,
2323
expm,
24+
lu,
2425
solve,
2526
solve_continuous_lyapunov,
2627
solve_discrete_are,
@@ -473,6 +474,67 @@ def test_solve_dtype(self):
473474
assert x.dtype == x_result.dtype, (A_dtype, b_dtype)
474475

475476

477+
@pytest.mark.parametrize("permute_l", [True, False], ids=["permute_l", "no_permute_l"])
478+
@pytest.mark.parametrize("p_indices", [True, False], ids=["p_indices", "no_p_indices"])
479+
@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
480+
def test_lu_decomposition(permute_l, p_indices, complex):
481+
dtype = config.floatX if not complex else f"complex{int(config.floatX[-2:]) * 2}"
482+
A = tensor("A", shape=(None, None), dtype=dtype)
483+
out = lu(A, permute_l=permute_l, p_indices=p_indices)
484+
485+
f = pytensor.function([A], out)
486+
487+
rng = np.random.default_rng(utt.fetch_seed())
488+
x = rng.normal(size=(5, 5)).astype(config.floatX)
489+
if complex:
490+
x = x + 1j * rng.normal(size=(5, 5)).astype(config.floatX)
491+
492+
out = f(x)
493+
494+
if permute_l:
495+
PL, U = out
496+
x_rebuilt = PL @ U
497+
elif p_indices:
498+
p, L, U = out
499+
P = np.eye(5)[p]
500+
x_rebuilt = P @ L @ U
501+
else:
502+
P, L, U = out
503+
x_rebuilt = P @ L @ U
504+
505+
np.testing.assert_allclose(x, x_rebuilt)
506+
scipy_out = scipy.linalg.lu(x, permute_l=permute_l, p_indices=p_indices)
507+
508+
for a, b in zip(out, scipy_out, strict=True):
509+
np.testing.assert_allclose(a, b)
510+
511+
512+
@pytest.mark.parametrize("grad_case", [0, 1, 2], ids=["U_only", "L_only", "U_and_L"])
513+
@pytest.mark.parametrize("permute_l", [True, False])
514+
@pytest.mark.parametrize("p_indices", [True, False])
515+
def test_lu_grad(grad_case, permute_l, p_indices):
516+
rng = np.random.default_rng(utt.fetch_seed())
517+
A_value = rng.normal(size=(5, 5))
518+
519+
def f_pt(A):
520+
out = lu(A, permute_l=permute_l, p_indices=p_indices)
521+
522+
if permute_l:
523+
L, U = out
524+
else:
525+
_, L, U = out
526+
527+
match grad_case:
528+
case 0:
529+
return U.sum()
530+
case 1:
531+
return L.sum()
532+
case 2:
533+
return U.sum() + L.sum()
534+
535+
utt.verify_grad(f_pt, [A_value], rng=rng)
536+
537+
476538
def test_cho_solve():
477539
rng = np.random.default_rng(utt.fetch_seed())
478540
A = matrix()

0 commit comments

Comments
 (0)