Skip to content

Commit bfa93aa

Browse files
Fix pivot_to_permutation
1 parent 3650e99 commit bfa93aa

File tree

2 files changed

+22
-28
lines changed

2 files changed

+22
-28
lines changed

pytensor/tensor/slinalg.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -585,41 +585,30 @@ def lu(
585585

586586

587587
class PivotToPermutations(Op):
588-
__props__ = ("inverse", "inplace")
588+
__props__ = ("inverse",)
589589

590-
def __init__(self, inverse=True, inplace=False):
590+
def __init__(self, inverse=True):
591591
self.inverse = inverse
592-
self.inplace = inplace
593-
self.destroy_map = {}
594-
if self.inplace:
595-
self.destroy_map = {0: [0]}
596592

597593
def make_node(self, pivots):
598594
pivots = as_tensor_variable(pivots)
599595
if pivots.ndim != 1:
600596
raise ValueError("PivotToPermutations only works on 1-D inputs")
601-
permutations = pivots.type()
602597

598+
permutations = pivots.type.clone(dtype="int64")()
603599
return Apply(self, [pivots], [permutations])
604600

605-
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
606-
if 0 in allowed_inplace_inputs:
607-
new_props = self._props_dict() # type: ignore
608-
new_props["inplace"] = True
609-
return type(self)(**new_props)
610-
else:
611-
return self
612-
613601
def perform(self, node, inputs, outputs):
614-
[p] = inputs
615-
p_inv = np.arange(len(p)).astype(p.dtype)
616-
for i in range(len(p)):
617-
p_inv[i], p_inv[p[i]] = p_inv[p[i]], p_inv[i]
602+
[pivots] = inputs
603+
p_inv = np.arange(len(pivots), dtype=pivots.dtype)
604+
605+
for i in range(len(pivots)):
606+
p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i]
618607

619608
if self.inverse:
620609
outputs[0][0] = p_inv
621-
622-
outputs[0][0] = np.argsort(p_inv)
610+
else:
611+
outputs[0][0] = np.argsort(p_inv)
623612

624613

625614
def pivot_to_permutation(p: TensorLike, inverse=False) -> Variable:
@@ -629,14 +618,14 @@ def pivot_to_permutation(p: TensorLike, inverse=False) -> Variable:
629618

630619
class LUFactor(Op):
631620
__props__ = ("overwrite_a", "check_finite", "permutation_indices")
621+
gufunc_signature = "(m,m)->(m,m),(m)"
632622

633623
def __init__(
634624
self, *, overwrite_a=False, check_finite=True, permutation_indices=False
635625
):
636626
self.overwrite_a = overwrite_a
637627
self.check_finite = check_finite
638628
self.permutation_indices = permutation_indices
639-
self.gufunc_signature = "(m,m)->(m,m),(m)"
640629

641630
if self.overwrite_a:
642631
self.destroy_map = {1: [0]}

tests/tensor/test_slinalg.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
Solve,
1919
SolveBase,
2020
SolveTriangular,
21-
_pivot_to_permutation,
2221
block_diag,
2322
cho_solve,
2423
cholesky,
@@ -27,6 +26,7 @@
2726
lu,
2827
lu_factor,
2928
lu_solve,
29+
pivot_to_permutation,
3030
solve,
3131
solve_continuous_lyapunov,
3232
solve_discrete_are,
@@ -667,14 +667,19 @@ def f_pt(A):
667667
utt.verify_grad(f_pt, [A_value], rng=rng)
668668

669669

670-
def test_pivot_to_permutation():
670+
@pytest.mark.parametrize("inverse", [True, False], ids=["inverse", "no_inverse"])
671+
def test_pivot_to_permutation(inverse):
671672
rng = np.random.default_rng(utt.fetch_seed())
672673
A_val = rng.normal(size=(5, 5))
673674
_, pivots = scipy.linalg.lu_factor(A_val)
674675
perm_idx, *_ = scipy.linalg.lu(A_val, p_indices=True)
675676

676-
permutations = pt.argsort(_pivot_to_permutation(pivots)).eval()
677-
np.testing.assert_array_equal(permutations, perm_idx)
677+
if not inverse:
678+
perm_idx_pt = pivot_to_permutation(pivots, inverse=False).eval()
679+
np.testing.assert_array_equal(perm_idx_pt, perm_idx)
680+
else:
681+
p_inv_pt = pivot_to_permutation(pivots, inverse=True).eval()
682+
np.testing.assert_array_equal(p_inv_pt, np.argsort(perm_idx))
678683

679684

680685
class TestLUSolve(utt.InferShapeTester):
@@ -686,8 +691,8 @@ def factor_and_solve(A, b, sum=False, **lu_kwargs):
686691
return x
687692
return x.sum()
688693

689-
@pytest.mark.parametrize("b_shape", [(5,), (5, 5)])
690-
@pytest.mark.parametrize("trans", [True, False])
694+
@pytest.mark.parametrize("b_shape", [(5,), (5, 5)], ids=["b_vec", "b_matrix"])
695+
@pytest.mark.parametrize("trans", [True, False], ids=["x_T", "x"])
691696
def test_lu_solve(self, b_shape: tuple[int], trans):
692697
rng = np.random.default_rng(utt.fetch_seed())
693698
A = pt.tensor("A", shape=(5, 5))

0 commit comments

Comments
 (0)