From fcd4af1c7570b08ec5d65e4a9e6ee5c209509b3a Mon Sep 17 00:00:00 2001 From: Alexander Condello Date: Thu, 4 Sep 2025 14:11:36 -0700 Subject: [PATCH] Fix the creation of Permutation symbols --- dwave/optimization/symbols.pyx | 83 ++++++++++++------- ...-symbol-construction-c80264f228458179.yaml | 5 ++ tests/test_symbols.py | 47 ++++++++++- 3 files changed, 99 insertions(+), 36 deletions(-) create mode 100644 releasenotes/notes/fix-Permutation-symbol-construction-c80264f228458179.yaml diff --git a/dwave/optimization/symbols.pyx b/dwave/optimization/symbols.pyx index 5fe43641..f2818e51 100644 --- a/dwave/optimization/symbols.pyx +++ b/dwave/optimization/symbols.pyx @@ -663,10 +663,6 @@ cdef class ArgSort(ArraySymbol): _register(ArgSort, typeid(cppArgSortNode)) -cdef bool _empty_slice(object slice_) noexcept: - return slice_.start is None and slice_.stop is None and slice_.step is None - - cdef class AdvancedIndexing(ArraySymbol): """Advanced indexing. @@ -711,35 +707,50 @@ cdef class AdvancedIndexing(ArraySymbol): array = next(self.iter_predecessors()) - if ( - isinstance(array, Constant) - and array.ndim() == 2 - and array.shape()[0] == array.shape()[1] # square matrix - and self.ptr.indices().size() == 2 - and isinstance(index, tuple) - and len(index) == 2 - ): - i0, i1 = index + if (perm := self._check_permutation(index)) is not None: + return perm - # check the [x, :][:, x] case - if (isinstance(i0, slice) and _empty_slice(i0) and - isinstance(i1, ArraySymbol) and - holds_alternative[cppArrayNodePtr](self.ptr.indices()[0]) and - get[cppArrayNodePtr](self.ptr.indices()[0]) == (i1).array_ptr and - holds_alternative[cppSlice](self.ptr.indices()[1])): + return super().__getitem__(index) - return Permutation(array, i1) + cdef object _check_permutation(self, index): + """Return a Permutation symbol if the indexing the symbol results in + a permutation, otherwise return None. + """ - # check the [:, x][x, :] case - if (isinstance(i1, slice) and _empty_slice(i1) and - isinstance(i0, ArraySymbol) and - holds_alternative[cppArrayNodePtr](self.ptr.indices()[1]) and - get[cppArrayNodePtr](self.ptr.indices()[1]) == (i0).array_ptr and - holds_alternative[cppSlice](self.ptr.indices()[0])): + # The indexed array must be a Constant square matrix + array = next(self.iter_predecessors()) + if not isinstance(array, Constant): + return None + if array.ndim() != 2 or array.shape()[0] != array.shape()[1]: + return None - return Permutation(array, i0) + # The total operation must of the form A[i0, i1][i2, i3] + + i0, i1 = self._iter_indices() + i2, i3 = index if isinstance(index, tuple) else (index, slice(None, None, None)) + + # It also must be of the form A[outer, inner][inner, outer] + + if ( + isinstance(i0, slice) and isinstance(i3, slice) # outer is a slice + and i0 == i3 and i0 == slice(None) # and those slices are empty + and isinstance(i1, ArraySymbol) and isinstance(i2, ArraySymbol) # inner is an array + and i1.id() == i2.id() # and those arrays are the same array + and i1.shape() == (array.shape()[0],) # and the shape of the array is correct + ): + return Permutation(array, i1) + + if ( + isinstance(i1, slice) and isinstance(i2, slice) # inner is a slice + and i1 == i2 and i1 == slice(None) # and those slices are empty + and isinstance(i0, ArraySymbol) and isinstance(i3, ArraySymbol) # outer is an array + and i0.id() == i3.id() # and those arrays are the same array + and i0.shape() == (array.shape()[0],) # and the shape of the array is correct + ): + return Permutation(array, i0) + + return None - return super().__getitem__(index) @classmethod def _from_symbol(cls, Symbol symbol): @@ -791,6 +802,17 @@ cdef class AdvancedIndexing(ArraySymbol): zf.writestr(directory + "indices.json", encoder.encode(indices)) + def _iter_indices(self): + for variant in self.ptr.indices(): + if holds_alternative[cppSlice](variant): + cppslice = get[cppSlice](variant) + yield slice(None) + elif holds_alternative[cppArrayNodePtr](variant): + array_ptr = get[cppArrayNodePtr](variant) + yield symbol_from_ptr(self.model, array_ptr) + else: + raise RuntimeError("unexpected variant contents") + cdef cppAdvancedIndexingNode* ptr _register(AdvancedIndexing, typeid(cppAdvancedIndexingNode)) @@ -3474,10 +3496,7 @@ cdef class Permutation(ArraySymbol): >>> type(p) """ - def __init__(self, Constant array, ListVariable x): - # todo: Loosen the types accepted. But this Cython code doesn't yet have - # the type heirarchy needed so for how we specify explicitly - + def __init__(self, Constant array, ArraySymbol x): if array.model is not x.model: raise ValueError("array and x do not share the same underlying model") diff --git a/releasenotes/notes/fix-Permutation-symbol-construction-c80264f228458179.yaml b/releasenotes/notes/fix-Permutation-symbol-construction-c80264f228458179.yaml new file mode 100644 index 00000000..701996db --- /dev/null +++ b/releasenotes/notes/fix-Permutation-symbol-construction-c80264f228458179.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Fix indexing operations sometimes raising an error when it should + create a ``Permutation`` symbol. diff --git a/tests/test_symbols.py b/tests/test_symbols.py index bedd0b23..3dee7cbb 100644 --- a/tests/test_symbols.py +++ b/tests/test_symbols.py @@ -2708,7 +2708,19 @@ def generate_symbols(self): model.lock() yield p - def test(self): + def test_constant_integer(self): + from dwave.optimization.symbols import Permutation + + model = Model() + + A = model.constant(np.arange(25).reshape((5, 5))) + x = model.constant(np.arange(5)) + + self.assertIsInstance(A[x, :][:, x], Permutation) + self.assertIsInstance(A[:, x][x, :], Permutation) + self.assertIsInstance(A[:, x][x], Permutation) + + def test_list_indexer(self): from dwave.optimization.symbols import Permutation model = Model() @@ -2718,11 +2730,38 @@ def test(self): self.assertIsInstance(A[x, :][:, x], Permutation) self.assertIsInstance(A[:, x][x, :], Permutation) + self.assertIsInstance(A[:, x][x], Permutation) + + def test_not_permutation(self): + # Some "near" permutations that aren't quite right + from dwave.optimization.symbols import Permutation + + with self.subTest("A not square"): + model = Model() + + A = model.constant(np.arange(30).reshape((5, 6))) + x = model.list(5) + + self.assertNotIsInstance(A[x, :][:, x], Permutation) + self.assertNotIsInstance(A[:, x][x, :], Permutation) + + with self.subTest("A not 2d"): + model = Model() + + A = model.constant(np.arange(25).reshape((5, 5, 1))) + x = model.list(5) + + self.assertNotIsInstance(A[x, :][:, x], Permutation) + self.assertNotIsInstance(A[:, x][x, :], Permutation) + + with self.subTest("indexer wrong size"): + model = Model() - b = model.constant(np.arange(30).reshape((5, 6))) + A = model.constant(np.arange(25).reshape((5, 5))) + x = model.list(4) - self.assertNotIsInstance(b[x, :][:, x], Permutation) - self.assertNotIsInstance(b[:, x][x, :], Permutation) + self.assertNotIsInstance(A[x, :][:, x], Permutation) + self.assertNotIsInstance(A[:, x][x, :], Permutation) class TestProd(utils.ReduceTests):