11 | 11 | import pytensor |
12 | 12 | import pytensor.tensor as pt |
13 | 13 | from pytensor.gradient import DisconnectedType |
14 | | -from pytensor.graph.basic import Apply |
| 14 | +from pytensor.graph.basic import Apply, Variable |
15 | 15 | from pytensor.graph.op import Op |
16 | 16 | from pytensor.tensor import TensorLike, as_tensor_variable |
17 | 17 | from pytensor.tensor import basic as ptb |
@@ -733,6 +733,103 @@ def lu_factor( |
733 | 733 | ) |
734 | 734 |
735 | 735 |
| 736 | +class LUSolve(Op): |
| 737 | + """ |
| 738 | + Solve a system of linear equations given the LU factorization of the matrix. |
| 739 | + """ |
| 740 | + |
| 741 | + __props__ = ("trans", "overwrite_b", "check_finite", "b_ndim") |
| 742 | + |
| 743 | + def __init__(self, b_ndim, trans=False, overwrite_b=False, check_finite=True): |
| 744 | + self.trans = trans |
| 745 | + self.overwrite_b = overwrite_b |
| 746 | + self.check_finite = check_finite |
| 747 | + |
| 748 | + assert b_ndim in (1, 2) |
| 749 | + self.b_ndim = b_ndim |
| 750 | + |
| 751 | + if b_ndim == 1: |
| 752 | + self.gufunc_signature = "(m,m),(m),(m)->(m)" |
| 753 | + else: |
| 754 | + self.gufunc_signature = "(m,m),(m),(m,n)->(m,n)" |
| 755 | + |
| 756 | + if overwrite_b: |
| 757 | + self.destroy_map = {0: [2]} |
| 758 | + |
| 759 | + def make_node(self, LU, pivots, b): |
| 760 | + LU = as_tensor_variable(LU) |
| 761 | + pivots = as_tensor_variable(pivots) |
| 762 | + b = as_tensor_variable(b) |
| 763 | + |
| 764 | + if LU.type.ndim != 2: |
| 765 | + raise TypeError( |
| 766 | + f"LU only allowed on matrix (2-D) inputs, got {LU.type.ndim}-D input" |
| 767 | + ) |
| 768 | + |
| 769 | + x = tensor(name="x", shape=b.type.shape, dtype=b.type.dtype) |
| 770 | + return Apply(self, [LU, pivots, b], [x]) |
| 771 | + |
| 772 | + def infer_shape(self, fgraph, node, shapes): |
| 773 | + LU_shape, pivot_shape, b_shape = shapes |
| 774 | + rows = LU_shape[1] |
| 775 | + if len(b_shape) == 1: |
| 776 | + return [(rows,)] |
| 777 | + else: |
| 778 | + cols = b_shape[1] |
| 779 | + return [(rows, cols)] |
| 780 | + |
| 781 | + def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": |
| 782 | + if 2 in allowed_inplace_inputs: |
| 783 | + new_props = self._props_dict() # type: ignore |
| 784 | + new_props["overwrite_b"] = True |
| 785 | + return type(self)(**new_props) |
| 786 | + else: |
| 787 | + return self |
| 788 | + |
| 789 | + def perform(self, node, inputs, outputs): |
| 790 | + LU, pivots, b = inputs |
| 791 | + |
| 792 | + outputs[0][0] = scipy_linalg.lu_solve( |
| 793 | + lu_and_piv=(LU, pivots), |
| 794 | + b=b, |
| 795 | + check_finite=self.check_finite, |
| 796 | + trans=self.trans, |
| 797 | + overwrite_b=self.overwrite_b, |
| 798 | + ) |
| 799 | + |
| 800 | + def L_op( |
| 801 | + self, |
| 802 | + inputs: Sequence[Variable], |
| 803 | + outputs: Sequence[Variable], |
| 804 | + output_grads: Sequence[Variable], |
| 805 | + ) -> list[Variable]: |
| 806 | + LU, pivots, b = inputs |
| 807 | + [x] = outputs |
| 808 | + [x_bar] = output_grads |
| 809 | + |
| 810 | + p_inv = _pivot_to_permutation(pivots) |
| 811 | + p = pt.argsort(p_inv) |
| 812 | + P = ptb.identity_like(LU)[p] |
| 813 | + |
| 814 | + # We are solving PLUx = b |
| 815 | + # Forward sensitivity will be dX = (LU)^{-1} (P.T @ db - dLU @ x) |
| 816 | + # Backward sensitivities are: |
| 817 | + # B_bar = P @ (LU)^{-T} @ X_bar |
| 818 | + # LU_bar = (-X @ X_bar.T @ (LU)^{-1}).T = -(LU)^{-T} @ X_bar @ X.T = -P.T @ B_bar @ X.T |
| 819 | + |
| 820 | + # Note that (P L U)^{-T} = P (LU)^{-T} (because P is orthogonal), so we can just directly lu_solve for b_bar |
| 821 | + # with trans = not trans |
| 822 | + new_props = self._props_dict() # type: ignore |
| 823 | + new_props["trans"] = not new_props["trans"] |
| 824 | + b_bar = type(self)(**new_props)(LU, pivots, x_bar) |
| 825 | + LU_bar = -P.T @ ptm.outer(b_bar, x) if x.ndim == 1 else -P.T @ b_bar @ x.T |
| 826 | + |
| 827 | +        # The integer pivots carry no useful gradient; return zeros, treating the permutation as locally constant |
| 828 | +        pivots_bar = pt.zeros(pivots.shape, dtype=LU.type.dtype) |
| 829 | + |
| 830 | +        return [LU_bar, pivots_bar, b_bar] |
| 831 | + |
| 832 | + |
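A quick numerical sanity check of the adjoint reasoning in `L_op`, using SciPy directly. This is a sketch under two assumptions: the inline swap loop only illustrates what `_pivot_to_permutation` is taken to compute, and the `trans=1` solve stands in for the `trans = not trans` trick used above.

```python
import numpy as np
from scipy import linalg as scipy_linalg

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
b = rng.normal(size=4)
x_bar = rng.normal(size=4)  # an arbitrary cotangent on x

LU, pivots = scipy_linalg.lu_factor(A)
x = scipy_linalg.lu_solve((LU, pivots), b)

# What _pivot_to_permutation is assumed to compute: apply the recorded row
# swaps to arange(n) to get p_inv with A[p_inv] = L @ U.
p_inv = np.arange(4)
for i, j in enumerate(pivots):
    p_inv[i], p_inv[j] = p_inv[j], p_inv[i]
p = np.argsort(p_inv)

L = np.tril(LU, k=-1) + np.eye(4)
U = np.triu(LU)
P = np.eye(4)[p]
np.testing.assert_allclose(A, P @ L @ U, atol=1e-12)  # A = P @ L @ U, as in the comment

# b_bar = P @ (LU)^{-T} @ x_bar is a plain lu_solve with trans flipped,
# and must match the adjoint solve A^{-T} @ x_bar.
b_bar = scipy_linalg.lu_solve((LU, pivots), x_bar, trans=1)
np.testing.assert_allclose(b_bar, np.linalg.solve(A.T, x_bar))
```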
736 | 833 | def lu_solve( |
737 | 834 | LU_and_pivots: tuple[TensorVariable, TensorVariable], |
738 | 835 | b: TensorVariable, |
|
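For context, typical use goes through the `lu_factor` and `lu_solve` helpers defined in this module rather than instantiating `LUSolve` directly. A minimal sketch, assuming `lu_factor` returns the `(LU, pivots)` pair that `lu_solve` expects (mirroring SciPy):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import lu_factor, lu_solve

A = pt.dmatrix("A")
b = pt.dvector("b")

# Factor once, then solve with the cached factorization.
LU, pivots = lu_factor(A)
x = lu_solve((LU, pivots), b)

f = pytensor.function([A, b], x)

rng = np.random.default_rng(0)
A_val = rng.normal(size=(5, 5))
b_val = rng.normal(size=5)
np.testing.assert_allclose(A_val @ f(A_val, b_val), b_val, atol=1e-10)
```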