|
21 | 21 | from pytensor.tensor.blockwise import Blockwise |
22 | 22 | from pytensor.tensor.nlinalg import kron, matrix_dot |
23 | 23 | from pytensor.tensor.shape import reshape |
24 | | -from pytensor.tensor.type import matrix, tensor, vector |
| 24 | +from pytensor.tensor.type import ivector, matrix, tensor, vector |
25 | 25 | from pytensor.tensor.variable import TensorVariable |
26 | 26 |
|
27 | 27 |
|
@@ -583,31 +583,29 @@ def lu( |
583 | 583 | ) |
584 | 584 |
|
585 | 585 |
|
586 | | -def _pivot_to_permutation(pivots): |
587 | | - """ |
588 | | - Converts a sequence of row exchanges to a permutation matrix that represents the same row exchanges. This |
589 | | - represents the inverse permutation, which can be used to reconstruct the original matrix from its LU factorization. |
590 | | - To get the actual permutation, the inverse permutation must be argsorted. |
591 | | - """ |
| 586 | +class PivotToPermutations(Op): |
| 587 | + itypes = [ivector] |
| 588 | + otypes = [ivector] |
592 | 589 |
|
593 | | - def step(i, permutation, swaps): |
594 | | - j = swaps[i] |
595 | | - x = permutation[i] |
596 | | - y = permutation[j] |
597 | | - |
598 | | - permutation = permutation[i].set(y) |
599 | | - return permutation[j].set(x) |
600 | | - |
601 | | - pivots = as_tensor_variable(pivots) |
602 | | - n = pivots.shape[0] |
603 | | - p_inv, _ = pytensor.scan( |
604 | | - step, |
605 | | - sequences=[pt.arange(n.copy())], |
606 | | - outputs_info=[pt.arange(n.copy())], |
607 | | - non_sequences=[pivots], |
608 | | - ) |
| 590 | + __props__ = () |
| 591 | + |
| 592 | + def make_node(self, pivots): |
| 593 | + pivots = as_tensor_variable(pivots) |
| 594 | + if pivots.ndim != 1: |
| 595 | + raise ValueError("PivotToPermutations only works on 1-D inputs") |
| 596 | + permutations = pivots.type() |
| 597 | + |
| 598 | + return Apply(self, [pivots], [permutations]) |
609 | 599 |
|
610 | | - return p_inv[-1] |
| 600 | + def perform(self, node, inputs, outputs): |
| 601 | + [p] = inputs |
| 602 | + p_inv = np.arange(len(p)).astype(p.dtype) |
| 603 | + for i in range(len(p)): |
| 604 | + p_inv[i], p_inv[p[i]] = p_inv[p[i]], p_inv[i] |
| 605 | + outputs[0][0] = p_inv |
| 606 | + |
| 607 | + |
| 608 | +_pivot_to_permutation = PivotToPermutations() |
611 | 609 |
|
612 | 610 |
|
613 | 611 | class LUFactor(Op): |
@@ -810,13 +808,7 @@ def lu_solve( |
810 | 808 | ) |
811 | 809 | x = x[pt.argsort(inv_permutation)] if trans else x |
812 | 810 |
|
813 | | - return LUSolve( |
814 | | - inputs=[LU, pivots, b], |
815 | | - outputs=[x], |
816 | | - trans=trans, |
817 | | - b_ndim=b_ndim, |
818 | | - check_finite=check_finite, |
819 | | - )(LU, pivots, b) |
| 811 | + return x |
820 | 812 |
|
821 | 813 |
|
822 | 814 | class SolveTriangular(SolveBase): |
|
0 commit comments