Skip to content

Commit 94d3aa3

Browse files
LUSolve Op (potentially useless)
1 parent 91815b3 commit 94d3aa3

File tree

1 file changed

+98
-1
lines changed

1 file changed

+98
-1
lines changed

pytensor/tensor/slinalg.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import pytensor
1212
import pytensor.tensor as pt
1313
from pytensor.gradient import DisconnectedType
14-
from pytensor.graph.basic import Apply
14+
from pytensor.graph.basic import Apply, Variable
1515
from pytensor.graph.op import Op
1616
from pytensor.tensor import TensorLike, as_tensor_variable
1717
from pytensor.tensor import basic as ptb
@@ -733,6 +733,103 @@ def lu_factor(
733733
)
734734

735735

736+
class LUSolve(Op):
737+
"""
738+
Solve a system of linear equations given the LU factorization of the matrix.
739+
"""
740+
741+
__props__ = ("trans", "overwrite_b", "check_finite", "b_ndim")
742+
743+
def __init__(self, b_ndim, trans=False, overwrite_b=False, check_finite=True):
744+
self.trans = trans
745+
self.overwrite_b = overwrite_b
746+
self.check_finite = check_finite
747+
748+
assert b_ndim in (1, 2)
749+
self.b_ndim = b_ndim
750+
751+
if b_ndim == 1:
752+
self.gufunc_signature = "(m,m),(m),(m)->(m)"
753+
else:
754+
self.gufunc_signature = "(m,m),(m),(m,n)->(m,n)"
755+
756+
if overwrite_b:
757+
self.destroy_map = {0: [2]}
758+
759+
def make_node(self, LU, pivots, b):
760+
LU = as_tensor_variable(LU)
761+
pivots = as_tensor_variable(pivots)
762+
b = as_tensor_variable(b)
763+
764+
if LU.type.ndim != 2:
765+
raise TypeError(
766+
f"LU only allowed on matrix (2-D) inputs, got {LU.type.ndim}-D input"
767+
)
768+
769+
x = tensor(name="x", shape=b.type.shape, dtype=b.type.dtype)
770+
return Apply(self, [LU, pivots, b], [x])
771+
772+
def infer_shape(self, fgraph, node, shapes):
773+
LU_shape, pivot_shape, b_shape = shapes
774+
rows = LU_shape[1]
775+
if len(b_shape) == 1:
776+
return [(rows,)]
777+
else:
778+
cols = b_shape[1]
779+
return [(rows, cols)]
780+
781+
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
782+
if 2 in allowed_inplace_inputs:
783+
new_props = self._props_dict() # type: ignore
784+
new_props["overwrite_b"] = True
785+
return type(self)(**new_props)
786+
else:
787+
return self
788+
789+
def perform(self, node, inputs, outputs):
790+
LU, pivots, b = inputs
791+
792+
outputs[0][0] = scipy_linalg.lu_solve(
793+
lu_and_piv=(LU, pivots),
794+
b=b,
795+
check_finite=self.check_finite,
796+
trans=self.trans,
797+
overwrite_b=self.overwrite_b,
798+
)
799+
800+
def L_op(
801+
self,
802+
inputs: Sequence[Variable],
803+
outputs: Sequence[Variable],
804+
output_grads: Sequence[Variable],
805+
) -> list[Variable]:
806+
LU, pivots, b = inputs
807+
[x] = outputs
808+
[x_bar] = output_grads
809+
810+
p_inv = _pivot_to_permutation(pivots)
811+
p = pt.argsort(p_inv)
812+
P = ptb.identity_like(LU)[p]
813+
814+
# We are solving PLUx = b
815+
# Forward sensitivity will be dX = (LU)^{-1} (P.T @ db - dLU @ x)
816+
# Backward sensitivities are:
817+
# B_bar = P @ (LU)^{-T} @ X_bar
818+
# LU_bar = (-X @ X_bar.T @ (LU)^{-1}).T = -(LU)^{-T} @ X_bar @ X.T = -P.T @ B_bar @ X.T
819+
820+
# Note that (P L U)^{-T} = P (LU)^{-T} (because P is orthogonal), so we can just directly lu_solve for b_bar
821+
# with trans = not trans
822+
new_props = self._props_dict() # type: ignore
823+
new_props["trans"] = not new_props["trans"]
824+
b_bar = type(self)(**new_props)(LU, pivots, x_bar)
825+
LU_bar = -P.T @ ptm.outer(b_bar, x) if x.ndim == 1 else -P.T @ b_bar @ x.T
826+
827+
# Pivots are always disconnected; we assume they are locally stable
828+
permutations_bar = pt.zeros(pivots.shape, dtype=LU.type.dtype)
829+
830+
return [LU_bar, permutations_bar, b_bar]
831+
832+
736833
def lu_solve(
737834
LU_and_pivots: tuple[TensorVariable, TensorVariable],
738835
b: TensorVariable,

0 commit comments

Comments
 (0)