
Commit 4e6944d

Add LUFactor op
1 parent 211a522 commit 4e6944d


3 files changed, +116 -1 lines changed


pytensor/link/numba/dispatch/slinalg.py

Lines changed: 1 addition & 1 deletion
@@ -483,7 +483,7 @@ def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
     getrf = scipy.linalg.get_lapack_funcs("getrf", (A,))
     A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)
 
-    return A_copy, ipiv
+    return A_copy, ipiv, info
 
 
 @overload(_getrf)
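
Note (not part of the diff): the helper now propagates LAPACK's `info` status instead of dropping it. A minimal standalone sketch of what the underlying getrf wrapper returns, using plain SciPy rather than the numba overload and a deliberately singular input:

# Sketch only: shows the (lu, piv, info) triple returned by LAPACK's getrf,
# which _getrf above now passes through in full.
import numpy as np
import scipy.linalg

A = np.array([[1.0, 2.0], [2.0, 4.0]])  # rank-deficient on purpose
getrf = scipy.linalg.get_lapack_funcs("getrf", (A,))
lu, piv, info = getrf(A, overwrite_a=False)

# info == 0 means success; info == i > 0 means U[i-1, i-1] is exactly zero
# (a singular factor); info < 0 flags an invalid argument.
print(info)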

pytensor/tensor/slinalg.py

Lines changed: 98 additions & 0 deletions
@@ -565,6 +565,103 @@ def lu(
     )
 
 
+class LUFactor(Op):
+    __props__ = ("overwrite_a", "check_finite")
+
+    def __init__(self, *, overwrite_a=False, check_finite=True):
+        self.overwrite_a = overwrite_a
+        self.check_finite = check_finite
+        self.gufunc_signature = "(m,m)->(m,m),(m)"
+
+        if self.overwrite_a:
+            self.destroy_map = {0: [0]}
+
+    def make_node(self, A):
+        A = as_tensor_variable(A)
+        if A.type.ndim != 2:
+            raise TypeError(
+                f"LU only allowed on matrix (2-D) inputs, got {A.type.ndim}-D input"
+            )
+
+        LU = matrix(shape=A.type.shape, dtype=A.type.dtype)
+        pivots = vector(shape=(A.type.shape[0],), dtype="int32")
+        return Apply(self, [A], [LU, pivots])
+
+    def infer_shape(self, fgraph, node, shapes):
+        n = shapes[0][0]
+        return [(n, n), (n,)]
+
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        if 0 in allowed_inplace_inputs:
+            new_props = self._props_dict()  # type: ignore
+            new_props["overwrite_a"] = True
+            return type(self)(**new_props)
+        else:
+            return self
+
+    def perform(self, node, inputs, outputs):
+        A = inputs[0]
+        LU, pivots = scipy_linalg.lu_factor(
+            A,
+            overwrite_a=self.overwrite_a,
+            check_finite=self.check_finite,
+        )
+
+        outputs[0][0] = LU
+        outputs[1][0] = pivots
+
+    def L_op(self, inputs, outputs, output_gradients):
+        A = inputs[0]
+        LU_bar, _ = output_gradients
+
+        # We need the permutation matrix P, not the pivot indices, so the easiest approach is another LU forward pass.
+        # An alternative would be to scan over the pivot indices and convert them into permutation indices; it is not
+        # clear whether that would be faster or slower.
+        P, L, U = lu(
+            A, permute_l=False, check_finite=self.check_finite, p_indices=False
+        )
+
+        # Split LU_bar into L_bar and U_bar. This is valid because of the triangular structure of L and U
+        L_bar = ptb.tril(LU_bar, k=-1)
+        U_bar = ptb.triu(LU_bar)
+
+        # From here we're in the same situation as the LU gradient derivation
+        x1 = ptb.tril(L.T @ L_bar, k=-1)
+        x2 = ptb.triu(U_bar @ U.T)
+
+        LT_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
+        A_bar = P @ solve_triangular(U, LT_inv_x.T, lower=False).T
+
+        return [A_bar]
+
+
+def lu_factor(
+    a: TensorLike, *, check_finite=True
+) -> tuple[TensorVariable, TensorVariable]:
+    """
+    LU factorization with partial pivoting.
+
+    Parameters
+    ----------
+    a: TensorLike
+        Matrix to be factorized
+    check_finite: bool
+        Whether to check that the input matrix contains only finite numbers.
+
+    Returns
+    -------
+    LU: TensorVariable
+        LU decomposition of `a`
+    pivots: TensorVariable
+        Permutation indices
+    """
+
+    return cast(
+        tuple[TensorVariable, TensorVariable],
+        Blockwise(LUFactor(check_finite=check_finite))(a),
+    )
+
+
 class SolveTriangular(SolveBase):
     """Solve a system of linear equations."""
 

@@ -1362,4 +1459,5 @@ def block_diag(*matrices: TensorVariable):
     "block_diag",
     "cho_solve",
     "lu",
+    "lu_factor",
 ]
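
For reference, the identity that `L_op` above implements, written out under the convention used by `lu(..., permute_l=False, p_indices=False)`, namely A = P L U with L unit lower triangular. Splitting the output cotangent \(\overline{LU}\) into \(\bar{L} = \mathrm{tril}_{-1}(\overline{LU})\) and \(\bar{U} = \mathrm{triu}(\overline{LU})\), the input cotangent is

\bar{A} = P \, L^{-T} \big( \mathrm{tril}_{-1}(L^{T}\bar{L}) + \mathrm{triu}(\bar{U} U^{T}) \big) \, U^{-T}

where \(\mathrm{tril}_{-1}\) keeps the strictly lower triangle and \(\mathrm{triu}\) keeps the upper triangle including the diagonal. The two `solve_triangular` calls in the code supply the \(L^{-T}\) and \(U^{-T}\) factors, and the permutation P is applied last; this is the standard LU adjoint, restated here only as a sketch of what the code computes.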

tests/tensor/test_slinalg.py

Lines changed: 17 additions & 0 deletions
@@ -22,6 +22,7 @@
     eigvalsh,
     expm,
     lu,
+    lu_factor,
     solve,
     solve_continuous_lyapunov,
     solve_discrete_are,

@@ -553,6 +554,22 @@ def f_pt(A):
     utt.verify_grad(f_pt, [A_value], rng=rng)
 
 
+def test_lu_factor():
+    rng = np.random.default_rng(utt.fetch_seed())
+    A = matrix()
+    A_val = rng.normal(size=(5, 5)).astype(config.floatX)
+
+    f = pytensor.function([A], lu_factor(A))
+
+    LU, pivots = f(A_val)
+    sp_LU, sp_pivots = scipy.linalg.lu_factor(A_val)
+
+    np.testing.assert_allclose(LU, sp_LU)
+    np.testing.assert_allclose(pivots, sp_pivots)
+
+    utt.verify_grad(lambda A: lu_factor(A)[0].sum(), [A_val], rng=rng)
+
+
 def test_cho_solve():
     rng = np.random.default_rng(utt.fetch_seed())
     A = matrix()