Skip to content

Commit cd8e0ae

Browse files
Implement specialized PivotToPermutation
1 parent 5d105d2 commit cd8e0ae

File tree

1 file changed

+23
-31
lines changed

1 file changed

+23
-31
lines changed

pytensor/tensor/slinalg.py

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pytensor.tensor.blockwise import Blockwise
2222
from pytensor.tensor.nlinalg import kron, matrix_dot
2323
from pytensor.tensor.shape import reshape
24-
from pytensor.tensor.type import matrix, tensor, vector
24+
from pytensor.tensor.type import ivector, matrix, tensor, vector
2525
from pytensor.tensor.variable import TensorVariable
2626

2727

@@ -583,31 +583,29 @@ def lu(
583583
)
584584

585585

586-
def _pivot_to_permutation(pivots):
587-
"""
588-
Converts a sequence of row exchanges to a permutation matrix that represents the same row exchanges. This
589-
represents the inverse permutation, which can be used to reconstruct the original matrix from its LU factorization.
590-
To get the actual permutation, the inverse permutation must be argsorted.
591-
"""
586+
class PivotToPermutations(Op):
587+
itypes = [ivector]
588+
otypes = [ivector]
592589

593-
def step(i, permutation, swaps):
594-
j = swaps[i]
595-
x = permutation[i]
596-
y = permutation[j]
597-
598-
permutation = permutation[i].set(y)
599-
return permutation[j].set(x)
600-
601-
pivots = as_tensor_variable(pivots)
602-
n = pivots.shape[0]
603-
p_inv, _ = pytensor.scan(
604-
step,
605-
sequences=[pt.arange(n.copy())],
606-
outputs_info=[pt.arange(n.copy())],
607-
non_sequences=[pivots],
608-
)
590+
__props__ = ()
591+
592+
def make_node(self, pivots):
593+
pivots = as_tensor_variable(pivots)
594+
if pivots.ndim != 1:
595+
raise ValueError("PivotToPermutations only works on 1-D inputs")
596+
permutations = pivots.type()
597+
598+
return Apply(self, [pivots], [permutations])
609599

610-
return p_inv[-1]
600+
def perform(self, node, inputs, outputs):
601+
[p] = inputs
602+
p_inv = np.arange(len(p)).astype(p.dtype)
603+
for i in range(len(p)):
604+
p_inv[i], p_inv[p[i]] = p_inv[p[i]], p_inv[i]
605+
outputs[0][0] = p_inv
606+
607+
608+
_pivot_to_permutation = PivotToPermutations()
611609

612610

613611
class LUFactor(Op):
@@ -810,13 +808,7 @@ def lu_solve(
810808
)
811809
x = x[pt.argsort(inv_permutation)] if trans else x
812810

813-
return LUSolve(
814-
inputs=[LU, pivots, b],
815-
outputs=[x],
816-
trans=trans,
817-
b_ndim=b_ndim,
818-
check_finite=check_finite,
819-
)(LU, pivots, b)
811+
return x
820812

821813

822814
class SolveTriangular(SolveBase):

0 commit comments

Comments
 (0)