Skip to content

Commit 152a586

Browse files
Add LUFactor op
1 parent 70ef520 commit 152a586

File tree

3 files changed

+116
-1
lines changed

3 files changed

+116
-1
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
483483
getrf = scipy.linalg.get_lapack_funcs("getrf", (A,))
484484
A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)
485485

486-
return A_copy, ipiv
486+
return A_copy, ipiv, info
487487

488488

489489
@overload(_getrf)

pytensor/tensor/slinalg.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,103 @@ def lu(
577577
)
578578

579579

580+
class LUFactor(Op):
    """LU factorization of a square matrix with partial pivoting.

    Thin graph wrapper around ``scipy.linalg.lu_factor``: produces a single
    packed matrix containing both triangular factors (U on and above the
    diagonal, L strictly below it with an implicit unit diagonal) together
    with the LAPACK-style pivot index vector.
    """

    __props__ = ("overwrite_a", "check_finite")

    def __init__(self, *, overwrite_a=False, check_finite=True):
        # overwrite_a: allow scipy to reuse the input array's storage.
        # check_finite: forwarded to scipy's input validation.
        self.overwrite_a = overwrite_a
        self.check_finite = check_finite
        # Square matrix in; packed LU matrix and pivot vector out.
        self.gufunc_signature = "(m,m)->(m,m),(m)"

        # Advertise in-place destruction of input 0 to the graph machinery
        # so it never reuses a destroyed input elsewhere.
        if self.overwrite_a:
            self.destroy_map = {0: [0]}

    def make_node(self, A):
        A = as_tensor_variable(A)
        if A.type.ndim != 2:
            raise TypeError(
                f"LU only allowed on matrix (2-D) inputs, got {A.type.ndim}-D input"
            )

        # Packed factors share the input's static shape and dtype; pivots
        # are int32 indices as returned by LAPACK's getrf.
        LU = matrix(shape=A.type.shape, dtype=A.type.dtype)
        pivots = vector(shape=(A.type.shape[0],), dtype="int32")
        return Apply(self, [A], [LU, pivots])

    def infer_shape(self, fgraph, node, shapes):
        # NOTE(review): assumes a square input, consistent with the
        # "(m,m)->(m,m),(m)" gufunc signature above.
        n = shapes[0][0]
        return [(n, n), (n,)]

    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
        # Return a copy of this Op that overwrites its input when the
        # rewriter permits destroying input 0; otherwise keep this instance.
        if 0 in allowed_inplace_inputs:
            new_props = self._props_dict()  # type: ignore
            new_props["overwrite_a"] = True
            return type(self)(**new_props)
        else:
            return self

    def perform(self, node, inputs, outputs):
        A = inputs[0]
        LU, pivots = scipy_linalg.lu_factor(
            A,
            overwrite_a=self.overwrite_a,
            check_finite=self.check_finite,
        )

        outputs[0][0] = LU
        outputs[1][0] = pivots

    def L_op(self, inputs, outputs, output_gradients):
        A = inputs[0]
        # Gradient flows only through the packed LU output; the integer
        # pivot output carries no useful gradient.
        LU_bar, _ = output_gradients

        # We need the permutation matrix P, not the pivot indices. Easiest way is to just do another LU forward.
        # Alternative is to do a scan over the pivot indices to convert them to permutation indices. I don't know if
        # that's faster or slower.
        P, L, U = lu(
            A, permute_l=False, check_finite=self.check_finite, p_indices=False
        )

        # Split LU_bar into L_bar and U_bar. This is valid because of the triangular structure of L and U
        L_bar = ptb.tril(LU_bar, k=-1)
        U_bar = ptb.triu(LU_bar)

        # From here we're in the same situation as the LU gradient derivation
        x1 = ptb.tril(L.T @ L_bar, k=-1)
        x2 = ptb.triu(U_bar @ U.T)

        LT_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
        A_bar = P @ solve_triangular(U, LT_inv_x.T, lower=False).T

        return [A_bar]
649+
650+
def lu_factor(
    a: TensorLike, *, check_finite=True
) -> tuple[TensorVariable, TensorVariable]:
    """
    LU factorization with partial pivoting.

    Parameters
    ----------
    a: TensorLike
        Matrix to be factorized
    check_finite: bool
        Whether to check that the input matrix contains only finite numbers.

    Returns
    -------
    LU: TensorVariable
        LU decomposition of `a`
    pivots: TensorVariable
        Permutation indices
    """
    # Wrap the core 2-D Op in Blockwise so batched (stacked-matrix)
    # inputs are handled by broadcasting over leading dimensions.
    core_op = LUFactor(check_finite=check_finite)
    outputs = Blockwise(core_op)(a)
    return cast(tuple[TensorVariable, TensorVariable], outputs)
675+
676+
580677
class SolveTriangular(SolveBase):
581678
"""Solve a system of linear equations."""
582679

@@ -1448,4 +1545,5 @@ def block_diag(*matrices: TensorVariable):
14481545
"block_diag",
14491546
"cho_solve",
14501547
"lu",
1548+
"lu_factor",
14511549
]

tests/tensor/test_slinalg.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
eigvalsh,
2525
expm,
2626
lu,
27+
lu_factor,
2728
solve,
2829
solve_continuous_lyapunov,
2930
solve_discrete_are,
@@ -664,6 +665,22 @@ def f_pt(A):
664665
utt.verify_grad(f_pt, [A_value], rng=rng)
665666

666667

668+
def test_lu_factor():
    """lu_factor matches scipy and its gradient (via the LU output) verifies."""
    rng = np.random.default_rng(utt.fetch_seed())

    x = matrix()
    x_val = rng.normal(size=(5, 5)).astype(config.floatX)

    factor_fn = pytensor.function([x], lu_factor(x))
    lu_val, piv_val = factor_fn(x_val)

    expected_lu, expected_piv = scipy.linalg.lu_factor(x_val)
    np.testing.assert_allclose(lu_val, expected_lu)
    np.testing.assert_allclose(piv_val, expected_piv)

    # Only the LU output is differentiable; pivots are integer-valued.
    utt.verify_grad(lambda a: lu_factor(a)[0].sum(), [x_val], rng=rng)
682+
683+
667684
def test_cho_solve():
668685
rng = np.random.default_rng(utt.fetch_seed())
669686
A = matrix()

0 commit comments

Comments (0)