Commit 23d2737

Restore LU Op
1 parent f8c9d7e commit 23d2737

File tree

1 file changed: +192 −26 lines changed

pytensor/tensor/slinalg.py

Lines changed: 192 additions & 26 deletions
@@ -1,6 +1,6 @@
 import logging
-import typing
 import warnings
+from collections.abc import Sequence
 from functools import reduce
 from typing import Literal, cast
 
@@ -9,6 +9,7 @@
 
 import pytensor
 import pytensor.tensor as pt
+from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor import TensorLike, as_tensor_variable
@@ -295,31 +296,16 @@ def L_op(self, inputs, outputs, output_gradients):
         # We need to return (dC/d[inv(A)], dC/db)
         c_bar = output_gradients[0]
 
-        solve_args = {k: getattr(self, k) for k in self.__props__}
-
-        # Some solvers can solve A.T x = b directly, without ever computing the transpose
-        has_trans = "transposed" in self.__props__
-
-        if has_trans:
-            # If the solver can do transposed solves, we do the opposite of the forward in the reverse. If we solved
-            # C = solve(A, b), then b_bar = solve(A.T, c_bar). If we solved C = solve(A.T, b), then
-            # b_bar = solve(A, c_bar)
-            solve_args["transposed"] = not solve_args["transposed"]
-            solve_op = type(self)(**solve_args)
-            b_bar = solve_op(A, c_bar)
-
-        else:
-            # Otherwise, we have to actually do the transpose of whatever was given
-            solve_op = type(self)(**solve_args)
-            b_bar = solve_op(A.T, c_bar)
+        trans_solve_op = type(self)(
+            **{
+                k: (not getattr(self, k) if k == "lower" else getattr(self, k))
+                for k in self.__props__
+            }
+        )
+        b_bar = trans_solve_op(A.T, c_bar)
 
         # force outer product if vector second input
-        A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar @ c.T
-
-        if has_trans and not solve_args["transposed"]:
-            # If we did a transposed solve in the forward pass, the program is expecting the
-            # gradients of A.T, not A
-            A_bar = A_bar.T
+        A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
 
         return [A_bar, b_bar]
 
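For reference, the gradient identities implemented by the rewritten L_op are: for c = solve(A, b) with upstream gradient c_bar, b_bar = solve(A.T, c_bar) and A_bar = -outer(b_bar, c). The `lower` flag is flipped in the comprehension because transposing A swaps lower and upper triangular structure. Below is a minimal NumPy sketch (not part of the commit; all names are illustrative) checking both identities against finite differences:

import numpy as np
from scipy.linalg import solve

rng = np.random.default_rng(0)
n = 4
A = rng.normal(size=(n, n)) + n * np.eye(n)  # keep A well-conditioned
b = rng.normal(size=n)
c_bar = rng.normal(size=n)  # upstream gradient of a scalar loss w.r.t. c

c = solve(A, b)

# Reverse-mode rules, as in the new L_op
b_bar = solve(A.T, c_bar)
A_bar = -np.outer(b_bar, c)

# b_bar check: d(c_bar @ c)/db = A^{-T} c_bar
assert np.allclose(b_bar, solve(A, np.eye(n)).T @ c_bar)

# Finite-difference check of A_bar for the scalar loss c_bar @ solve(A, b)
eps = 1e-6
A_bar_fd = np.zeros_like(A)
for i in range(n):
    for j in range(n):
        dA = np.zeros_like(A)
        dA[i, j] = eps
        A_bar_fd[i, j] = c_bar @ (solve(A + dA, b) - c) / eps

assert np.allclose(A_bar, A_bar_fd, atol=1e-4)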
@@ -396,6 +382,186 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
     )(A, b)
 
 
+class LU(Op):
+    """Decompose a matrix into lower and upper triangular matrices."""
+
+    __props__ = ("permute_l", "overwrite_a", "check_finite", "p_indices")
+
+    def __init__(
+        self, *, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
+    ):
+        self.permute_l = permute_l
+        self.check_finite = check_finite
+        self.p_indices = p_indices
+        self.overwrite_a = overwrite_a
+
+        if self.permute_l:
+            # permute_l overrides p_indices in the scipy function. We can copy that behavior
+            self.gufunc_signature = "(m,m)->(m,m),(m,m)"
+        elif self.p_indices:
+            self.gufunc_signature = "(m,m)->(m),(m,m),(m,m)"
+        else:
+            self.gufunc_signature = "(m,m)->(m,m),(m,m),(m,m)"
+
+        if self.overwrite_a:
+            self.destroy_map = {0: [0]}
+
+    def infer_shape(self, fgraph, node, shapes):
+        n = shapes[0][0]
+        if self.permute_l:
+            return [(n, n), (n, n)]
+        elif self.p_indices:
+            return [(n,), (n, n), (n, n)]
+        else:
+            return [(n, n), (n, n), (n, n)]
+
+    def make_node(self, x):
+        x = as_tensor_variable(x)
+        if x.type.ndim != 2:
+            raise TypeError(
+                f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
+            )
+
+        real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d"
+        p_dtype = "int32" if self.p_indices else np.dtype(real_dtype)
+
+        L = tensor(shape=x.type.shape, dtype=real_dtype)
+        U = tensor(shape=x.type.shape, dtype=real_dtype)
+
+        if self.permute_l:
+            # In this case, L is actually P @ L
+            return Apply(self, inputs=[x], outputs=[L, U])
+        elif self.p_indices:
+            p = tensor(shape=(x.type.shape[0],), dtype=p_dtype)
+            return Apply(self, inputs=[x], outputs=[p, L, U])
+        else:
+            P = tensor(shape=x.type.shape, dtype=p_dtype)
+            return Apply(self, inputs=[x], outputs=[P, L, U])
+
+    def perform(self, node, inputs, outputs):
+        [A] = inputs
+
+        out = scipy.linalg.lu(
+            A,
+            permute_l=self.permute_l,
+            overwrite_a=self.overwrite_a,
+            check_finite=self.check_finite,
+            p_indices=self.p_indices,
+        )
+
+        outputs[0][0] = out[0]
+        outputs[1][0] = out[1]
+
+        if not self.permute_l:
+            # In all cases except permute_l, there are three returns
+            outputs[2][0] = out[2]
+
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        if 0 in allowed_inplace_inputs:
+            new_props = self._props_dict()  # type: ignore
+            new_props["overwrite_a"] = True
+            return type(self)(**new_props)
+        else:
+            return self
+
+    def L_op(
+        self,
+        inputs: Sequence[ptb.Variable],
+        outputs: Sequence[ptb.Variable],
+        output_grads: Sequence[ptb.Variable],
+    ) -> list[ptb.Variable]:
+        r"""
+        Derivation is due to Differentiation of Matrix Functionals Using Triangular Factorization
+        F. R. De Hoog, R. S. Anderssen, M. A. Lukas
+        """
+        [A] = inputs
+        A = cast(TensorVariable, A)
+
+        if self.permute_l:
+            PL_bar, U_bar = output_grads
+
+            # TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
+            P, L, U = lu(  # type: ignore
+                A, permute_l=False, check_finite=self.check_finite, p_indices=False
+            )
+
+            # Permutation matrix is orthogonal
+            L_bar = (
+                P.T @ PL_bar
+                if not isinstance(PL_bar.type, DisconnectedType)
+                else pt.zeros_like(A)
+            )
+
+        elif self.p_indices:
+            p, L, U = outputs
+
+            # TODO: rewrite to p_indices = False for graphs where we need to compute the gradient
+            P = pt.eye(A.shape[0])[p]
+            _, L_bar, U_bar = output_grads
+        else:
+            P, L, U = outputs
+            _, L_bar, U_bar = output_grads
+
+        L_bar = (
+            L_bar if not isinstance(L_bar.type, DisconnectedType) else pt.zeros_like(A)
+        )
+        U_bar = (
+            U_bar if not isinstance(U_bar.type, DisconnectedType) else pt.zeros_like(A)
+        )
+
+        x1 = ptb.tril(L.T @ L_bar, k=-1)
+        x2 = ptb.triu(U_bar @ U.T)
+
+        L_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
+        A_bar = P @ solve_triangular(U, L_inv_x.T, lower=False).T
+
+        return [A_bar]
+
+
+def lu(
+    a: TensorLike, permute_l=False, check_finite=True, p_indices=False
+) -> (
+    tuple[TensorVariable, TensorVariable, TensorVariable]
+    | tuple[TensorVariable, TensorVariable]
+):
+    """
+    Factorize a matrix as the product of a unit lower triangular matrix and an upper triangular matrix:
+
+    .. math::
+
+        A = P L U
+
+    Where P is a permutation matrix, L is lower triangular with unit diagonal elements, and U is upper triangular.
+
+    Parameters
+    ----------
+    a: TensorLike
+        Matrix to be factorized
+    permute_l: bool
+        If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will
+        be returned in this case, and PL will not be lower triangular.
+    check_finite: bool
+        Whether to check that the input matrix contains only finite numbers.
+    p_indices: bool
+        If True, return integer matrix indices for the permutation matrix. Otherwise, return the permutation matrix
+        itself.
+
+    Returns
+    -------
+    P: TensorVariable
+        Permutation matrix, or array of integer indices for permutation matrix. Not returned if permute_l is True.
+    L: TensorVariable
+        Lower triangular matrix, or product of permutation and unit lower triangular matrices if permute_l is True.
+    U: TensorVariable
+        Upper triangular matrix
+    """
+    return cast(
+        tuple[TensorVariable, TensorVariable, TensorVariable]
+        | tuple[TensorVariable, TensorVariable],
+        LU(permute_l=permute_l, check_finite=check_finite, p_indices=p_indices)(a),
+    )
+
+
 class SolveTriangular(SolveBase):
     """Solve a system of linear equations."""
 
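The LU gradient implemented in L_op above follows De Hoog, Anderssen, and Lukas: writing F = tril(L.T @ L_bar, k=-1) + triu(U_bar @ U.T), the cotangent of A = P @ L @ U is A_bar = P @ inv(L.T) @ F @ inv(U.T), which the two triangular solves compute without forming any inverse. A standalone NumPy sketch checking that formula against finite differences (not part of the commit; it assumes the perturbation is small enough not to change scipy's pivot choice):

import numpy as np
from scipy.linalg import lu, solve_triangular

rng = np.random.default_rng(1)
n = 4
A = rng.normal(size=(n, n))
P, L, U = lu(A)  # A = P @ L @ U, with L unit lower triangular

# Arbitrary upstream cotangents; L's fixed unit diagonal carries no gradient
L_bar = np.tril(rng.normal(size=(n, n)), k=-1)
U_bar = np.triu(rng.normal(size=(n, n)))

# VJP formula from LU.L_op
F = np.tril(L.T @ L_bar, k=-1) + np.triu(U_bar @ U.T)
L_inv_F = solve_triangular(L.T, F, lower=False, unit_diagonal=True)
A_bar = P @ solve_triangular(U, L_inv_F.T, lower=False).T

# Finite-difference check for the scalar loss sum(L_bar * L) + sum(U_bar * U)
def loss(M):
    _, L_, U_ = lu(M)
    return np.sum(L_bar * L_) + np.sum(U_bar * U_)

eps = 1e-6
A_bar_fd = np.zeros_like(A)
for i in range(n):
    for j in range(n):
        dA = np.zeros_like(A)
        dA[i, j] = eps
        A_bar_fd[i, j] = (loss(A + dA) - loss(A)) / eps

assert np.allclose(A_bar, A_bar_fd, atol=1e-4)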
@@ -513,13 +679,13 @@ class Solve(SolveBase):
     def __init__(self, *, assume_a="gen", transposed=False, **kwargs):
         if assume_a not in ("gen", "sym", "her", "pos"):
             raise ValueError(f"{assume_a} is not a recognized matrix structure")
+
         super().__init__(**kwargs)
         self.assume_a = assume_a
         self.transposed = transposed
 
     def perform(self, node, inputs, outputs):
         a, b = inputs
-
         outputs[0][0] = scipy.linalg.solve(
             a=a,
             b=b,
@@ -1083,7 +1249,7 @@ def solve_discrete_are(
     )
 
 
-def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
+def _largest_common_dtype(tensors: Sequence[TensorVariable]) -> np.dtype:
     return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
 
 
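Taken together, the restored Op can be exercised roughly as follows. This is a hypothetical usage sketch, not part of the commit; it assumes lu is importable from pytensor.tensor.slinalg, as the diff suggests:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import lu

x = pt.matrix("x")
P, L, U = lu(x)  # defaults: permute_l=False, p_indices=False

# A scalar function of the factors; its gradient flows through LU.L_op,
# with the disconnected gradient for P handled via DisconnectedType
loss = (L**2).sum() + (U**2).sum()
g = pytensor.grad(loss, x)

f = pytensor.function([x], [P, L, U, g])
A = np.random.default_rng(2).normal(size=(3, 3))
P_val, L_val, U_val, g_val = f(A)
assert np.allclose(P_val @ L_val @ U_val, A)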