
Commit 100e467

Add LU decomposition Op
1 parent 4fa9bb8 commit 100e467

File tree

2 files changed: +249 -5 lines

pytensor/tensor/slinalg.py

Lines changed: 187 additions & 5 deletions
@@ -1,14 +1,14 @@
 import logging
-import typing
 import warnings
 from functools import reduce
-from typing import Literal, cast
+from typing import Literal, cast, Sequence
 
 import numpy as np
-import scipy.linalg
+import scipy
 
 import pytensor
 import pytensor.tensor as pt
+from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor import TensorLike, as_tensor_variable
@@ -302,6 +302,7 @@ def L_op(self, inputs, outputs, output_gradients):
             }
         )
         b_bar = trans_solve_op(A.T, c_bar)
+
         # force outer product if vector second input
         A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
 
@@ -369,7 +370,7 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
         Whether to check that the input matrices contain only finite numbers.
         Disabling may give a performance gain, but may result in problems
         (crashes, non-termination) if the inputs do contain infinities or NaNs.
-    b_ndim : int
+    b_ndim : int
         Whether the core case of b is a vector (1) or matrix (2).
         This will influence how batched dimensions are interpreted.
     """
@@ -380,6 +381,186 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
     )(A, b)
 
 
+class LU(Op):
+    """Decompose a matrix into lower and upper triangular matrices."""
+
+    __props__ = ("permute_l", "overwrite_a", "check_finite", "p_indices")
+
+    def __init__(
+        self, *, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
+    ):
+        self.permute_l = permute_l
+        self.check_finite = check_finite
+        self.p_indices = p_indices
+        self.overwrite_a = overwrite_a
+
+        if self.permute_l:
+            # permute_l overrides p_indices in the scipy function; we copy that behavior
+            self.gufunc_signature = "(m,m)->(m,m),(m,m)"
+        elif self.p_indices:
+            self.gufunc_signature = "(m,m)->(m),(m,m),(m,m)"
+        else:
+            self.gufunc_signature = "(m,m)->(m,m),(m,m),(m,m)"
+
+        if self.overwrite_a:
+            self.destroy_map = {0: [0]}
+
+    def infer_shape(self, fgraph, node, shapes):
+        n = shapes[0][0]
+        if self.permute_l:
+            return [(n, n), (n, n)]
+        elif self.p_indices:
+            return [(n,), (n, n), (n, n)]
+        else:
+            return [(n, n), (n, n), (n, n)]
+
+    def make_node(self, x):
+        x = as_tensor_variable(x)
+        if x.type.ndim != 2:
+            raise TypeError(
+                f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
+            )
+
+        real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d"
+        p_dtype = "int32" if self.p_indices else np.dtype(real_dtype)
+
+        L = tensor(shape=x.type.shape, dtype=real_dtype)
+        U = tensor(shape=x.type.shape, dtype=real_dtype)
+
+        if self.permute_l:
+            # In this case, L is actually P @ L
+            return Apply(self, inputs=[x], outputs=[L, U])
+        elif self.p_indices:
+            p = tensor(shape=(x.type.shape[0],), dtype=p_dtype)
+            return Apply(self, inputs=[x], outputs=[p, L, U])
+        else:
+            P = tensor(shape=x.type.shape, dtype=p_dtype)
+            return Apply(self, inputs=[x], outputs=[P, L, U])
+
+    def perform(self, node, inputs, outputs):
+        [A] = inputs
+
+        out = scipy.linalg.lu(
+            A,
+            permute_l=self.permute_l,
+            overwrite_a=self.overwrite_a,
+            check_finite=self.check_finite,
+            p_indices=self.p_indices,
+        )
+
+        outputs[0][0] = out[0]
+        outputs[1][0] = out[1]
+
+        if not self.permute_l:
+            # In all cases except permute_l, there are three returns
+            outputs[2][0] = out[2]
+
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        if 0 in allowed_inplace_inputs:
+            new_props = self._props_dict()  # type: ignore
+            new_props["overwrite_a"] = True
+            return type(self)(**new_props)
+        else:
+            return self
+
+    def L_op(
+        self,
+        inputs: Sequence[ptb.Variable],
+        outputs: Sequence[ptb.Variable],
+        output_grads: Sequence[ptb.Variable],
+    ) -> list[ptb.Variable]:
+        r"""
+        Derivation is due to Differentiation of Matrix Functionals Using Triangular Factorization,
+        F. R. De Hoog, R. S. Anderssen, M. A. Lukas
+        """
+        [A] = inputs
+        A = cast(TensorVariable, A)
+
+        if self.permute_l:
+            PL_bar, U_bar = output_grads
+
+            # TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
+            P, L, U = lu(  # type: ignore
+                A, permute_l=False, check_finite=self.check_finite, p_indices=False
+            )
+
+            # Permutation matrix is orthogonal
+            L_bar = (
+                P.T @ PL_bar
+                if not isinstance(PL_bar.type, DisconnectedType)
+                else pt.zeros_like(A)
+            )
+
+        elif self.p_indices:
+            p, L, U = outputs
+
+            # TODO: rewrite to p_indices = False for graphs where we need to compute the gradient
+            P = pt.eye(A.shape[0])[p]
+            _, L_bar, U_bar = output_grads
+        else:
+            P, L, U = outputs
+            _, L_bar, U_bar = output_grads
+
+        L_bar = (
+            L_bar if not isinstance(L_bar.type, DisconnectedType) else pt.zeros_like(A)
+        )
+        U_bar = (
+            U_bar if not isinstance(U_bar.type, DisconnectedType) else pt.zeros_like(A)
+        )
+
+        x1 = ptb.tril(L.T @ L_bar, k=-1)
+        x2 = ptb.triu(U_bar @ U.T)
+
+        L_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
+        A_bar = P @ solve_triangular(U, L_inv_x.T, lower=False).T
+
+        return [A_bar]
+
+
+def lu(
+    a: TensorLike, permute_l=False, check_finite=True, p_indices=False
+) -> (
+    tuple[TensorVariable, TensorVariable, TensorVariable]
+    | tuple[TensorVariable, TensorVariable]
+):
+    """
+    Factorize a matrix as the product of a permutation matrix, a unit lower triangular matrix, and an upper
+    triangular matrix:
+
+    .. math::
+
+        A = P L U
+
+    Where P is a permutation matrix, L is lower triangular with unit diagonal elements, and U is upper triangular.
+
+    Parameters
+    ----------
+    a: TensorLike
+        Matrix to be factorized
+    permute_l: bool
+        If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will
+        be returned in this case, and PL will not be lower triangular.
+    check_finite: bool
+        Whether to check that the input matrix contains only finite numbers.
+    p_indices: bool
+        If True, return integer matrix indices for the permutation matrix. Otherwise, return the permutation matrix
+        itself.
+
+    Returns
+    -------
+    P: TensorVariable
+        Permutation matrix, or array of integer indices for permutation matrix. Not returned if permute_l is True.
+    L: TensorVariable
+        Lower triangular matrix, or product of permutation and unit lower triangular matrices if permute_l is True.
+    U: TensorVariable
+        Upper triangular matrix
+    """
+    return cast(
+        tuple[TensorVariable, TensorVariable, TensorVariable]
+        | tuple[TensorVariable, TensorVariable],
+        LU(permute_l=permute_l, check_finite=check_finite, p_indices=p_indices)(a),
+    )
+
+
 class SolveTriangular(SolveBase):
     """Solve a system of linear equations."""
 
@@ -1064,7 +1245,7 @@ def solve_discrete_are(
     )
 
 
-def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
+def _largest_common_dtype(tensors: Sequence[TensorVariable]) -> np.dtype:
     return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
 
 
@@ -1175,4 +1356,5 @@ def block_diag(*matrices: TensorVariable):
     "solve_discrete_are",
     "solve_triangular",
     "block_diag",
+    "cho_solve",
 ]
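
In matrix form, the backward pass implemented by LU.L_op above amounts to the following pullback (a direct reading of the code, where tril_{-1} takes the strictly lower triangle and triu the upper triangle including the diagonal):

.. math::

    \bar{A} = P \, L^{-T} \left( \mathrm{tril}_{-1}(L^{T} \bar{L}) + \mathrm{triu}(\bar{U} U^{T}) \right) U^{-T}, \qquad A = P L U

The leading P relies on permutation matrices being orthogonal (P^{-1} = P^{T}), and both L^{-T} and U^{-T} are applied through triangular solves rather than explicit inverses; the solve against L^{T} additionally exploits the unit diagonal (unit_diagonal=True).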

tests/tensor/test_slinalg.py

Lines changed: 62 additions & 0 deletions
@@ -21,6 +21,7 @@
     cholesky,
     eigvalsh,
     expm,
+    lu,
     solve,
     solve_continuous_lyapunov,
     solve_discrete_are,
@@ -459,6 +460,67 @@ def test_solve_dtype(self):
         assert x.dtype == x_result.dtype, (A_dtype, b_dtype)
 
 
+@pytest.mark.parametrize("permute_l", [True, False], ids=["permute_l", "no_permute_l"])
+@pytest.mark.parametrize("p_indices", [True, False], ids=["p_indices", "no_p_indices"])
+@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
+def test_lu_decomposition(permute_l, p_indices, complex):
+    dtype = config.floatX if not complex else f"complex{int(config.floatX[-2:]) * 2}"
+    A = tensor("A", shape=(None, None), dtype=dtype)
+    out = lu(A, permute_l=permute_l, p_indices=p_indices)
+
+    f = pytensor.function([A], out)
+
+    rng = np.random.default_rng(utt.fetch_seed())
+    x = rng.normal(size=(5, 5)).astype(config.floatX)
+    if complex:
+        x = x + 1j * rng.normal(size=(5, 5)).astype(config.floatX)
+
+    out = f(x)
+
+    if permute_l:
+        PL, U = out
+        x_rebuilt = PL @ U
+    elif p_indices:
+        p, L, U = out
+        P = np.eye(5)[p]
+        x_rebuilt = P @ L @ U
+    else:
+        P, L, U = out
+        x_rebuilt = P @ L @ U
+
+    np.testing.assert_allclose(x, x_rebuilt)
+    scipy_out = scipy.linalg.lu(x, permute_l=permute_l, p_indices=p_indices)
+
+    for a, b in zip(out, scipy_out, strict=True):
+        np.testing.assert_allclose(a, b)
+
+
+@pytest.mark.parametrize("grad_case", [0, 1, 2], ids=["U_only", "L_only", "U_and_L"])
+@pytest.mark.parametrize("permute_l", [True, False])
+@pytest.mark.parametrize("p_indices", [True, False])
+def test_lu_grad(grad_case, permute_l, p_indices):
+    rng = np.random.default_rng(utt.fetch_seed())
+    A_value = rng.normal(size=(5, 5))
+
+    def f_pt(A):
+        out = lu(A, permute_l=permute_l, p_indices=p_indices)
+
+        if permute_l:
+            L, U = out
+        else:
+            _, L, U = out
+
+        match grad_case:
+            case 0:
+                return U.sum()
+            case 1:
+                return L.sum()
+            case 2:
+                return U.sum() + L.sum()
+
+    utt.verify_grad(f_pt, [A_value], rng=rng)
+
+
 def test_cho_solve():
     rng = np.random.default_rng(utt.fetch_seed())
     A = matrix()
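The gradient tests above also exercise the disconnected-output handling in LU.L_op: when only one factor reaches the cost, the output gradient for the other factor arrives as a DisconnectedType and is replaced by zeros. A sketch of differentiating through the factorization (same assumed import path as above):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import lu

    A = pt.matrix("A")
    _, L, U = lu(A)

    # Only U reaches the cost, so the output gradient for L is disconnected
    # and L_op substitutes zeros for it (the pt.zeros_like branches above).
    cost = U.sum()
    grad_A = pytensor.grad(cost, A)

    f = pytensor.function([A], grad_A)
    x = np.random.default_rng(0).normal(size=(5, 5))
    print(f(x).shape)  # -> (5, 5)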