Skip to content

Commit 05373a9

Browse files
Expand PivotToPermutations op
1 parent 8c96856 commit 05373a9

File tree

1 file changed

+27
-10
lines changed

1 file changed

+27
-10
lines changed

pytensor/tensor/slinalg.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pytensor.tensor.blockwise import Blockwise
2222
from pytensor.tensor.nlinalg import kron, matrix_dot
2323
from pytensor.tensor.shape import reshape
24-
from pytensor.tensor.type import ivector, matrix, tensor, vector
24+
from pytensor.tensor.type import matrix, tensor, vector
2525
from pytensor.tensor.variable import TensorVariable
2626

2727

@@ -585,10 +585,14 @@ def lu(
585585

586586

587587
class PivotToPermutations(Op):
588-
itypes = [ivector]
589-
otypes = [ivector]
588+
__props__ = ("inverse", "inplace")
590589

591-
__props__ = ()
590+
def __init__(self, inverse=True, inplace=False):
591+
self.inverse = inverse
592+
self.inplace = inplace
593+
self.destroy_map = {}
594+
if self.inplace:
595+
self.destroy_map = {0: [0]}
592596

593597
def make_node(self, pivots):
594598
pivots = as_tensor_variable(pivots)
@@ -598,15 +602,29 @@ def make_node(self, pivots):
598602

599603
return Apply(self, [pivots], [permutations])
600604

605+
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
606+
if 0 in allowed_inplace_inputs:
607+
new_props = self._props_dict() # type: ignore
608+
new_props["inplace"] = True
609+
return type(self)(**new_props)
610+
else:
611+
return self
612+
601613
def perform(self, node, inputs, outputs):
602614
[p] = inputs
603615
p_inv = np.arange(len(p)).astype(p.dtype)
604616
for i in range(len(p)):
605617
p_inv[i], p_inv[p[i]] = p_inv[p[i]], p_inv[i]
606-
outputs[0][0] = p_inv
618+
619+
if self.inverse:
620+
outputs[0][0] = p_inv
621+
622+
outputs[0][0] = np.argsort(p_inv)
607623

608624

609-
_pivot_to_permutation = PivotToPermutations()
625+
def pivot_to_permutation(p: TensorLike, inverse=False) -> Variable:
626+
p = pt.as_tensor_variable(p)
627+
return PivotToPermutations(inverse=inverse)(p)
610628

611629

612630
class LUFactor(Op):
@@ -631,7 +649,7 @@ def make_node(self, A):
631649
)
632650

633651
LU = matrix(shape=A.type.shape, dtype=A.type.dtype)
634-
pivots_or_permutations = vector(shape=(A.type.shape[0],), dtype="int32")
652+
pivots_or_permutations = vector(shape=(A.type.shape[0],), dtype="int64")
635653

636654
return Apply(self, [A], [LU, pivots_or_permutations])
637655

@@ -681,8 +699,7 @@ def L_op(self, inputs, outputs, output_gradients):
681699
U = cast(TensorVariable, ptb.triu(LU))
682700

683701
if not self.permutation_indices:
684-
p_indices_inv = _pivot_to_permutation(cast(TensorVariable, p_indices))
685-
p_indices = pt.argsort(p_indices_inv)
702+
p_indices = pivot_to_permutation(p_indices, inverse=False)
686703

687704
# Split LU_bar into L_bar and U_bar. This is valid because of the triangular structure of L and U
688705
L_bar = ptb.tril(LU_bar, k=-1)
@@ -784,7 +801,7 @@ def lu_solve(
784801
LU, pivots = LU_and_pivots
785802

786803
LU, pivots, b = map(pt.as_tensor_variable, [LU, pivots, b])
787-
inv_permutation = _pivot_to_permutation(pivots)
804+
inv_permutation = pivot_to_permutation(pivots, inverse=True)
788805

789806
x = b[inv_permutation] if not trans else b
790807

0 commit comments

Comments
 (0)