Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 51 additions & 32 deletions dwave/optimization/symbols.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -663,10 +663,6 @@ cdef class ArgSort(ArraySymbol):
_register(ArgSort, typeid(cppArgSortNode))


cdef bool _empty_slice(object slice_) noexcept:
return slice_.start is None and slice_.stop is None and slice_.step is None


cdef class AdvancedIndexing(ArraySymbol):
"""Advanced indexing.

Expand Down Expand Up @@ -711,35 +707,50 @@ cdef class AdvancedIndexing(ArraySymbol):

array = next(self.iter_predecessors())

if (
isinstance(array, Constant)
and array.ndim() == 2
and array.shape()[0] == array.shape()[1] # square matrix
and self.ptr.indices().size() == 2
and isinstance(index, tuple)
and len(index) == 2
):
i0, i1 = index
if (perm := self._check_permutation(index)) is not None:
return perm

# check the [x, :][:, x] case
if (isinstance(i0, slice) and _empty_slice(i0) and
isinstance(i1, ArraySymbol) and
holds_alternative[cppArrayNodePtr](self.ptr.indices()[0]) and
get[cppArrayNodePtr](self.ptr.indices()[0]) == (<ArraySymbol>i1).array_ptr and
holds_alternative[cppSlice](self.ptr.indices()[1])):
return super().__getitem__(index)

return Permutation(array, i1)
cdef object _check_permutation(self, index):
"""Return a Permutation symbol if the indexing the symbol results in
a permutation, otherwise return None.
"""

# check the [:, x][x, :] case
if (isinstance(i1, slice) and _empty_slice(i1) and
isinstance(i0, ArraySymbol) and
holds_alternative[cppArrayNodePtr](self.ptr.indices()[1]) and
get[cppArrayNodePtr](self.ptr.indices()[1]) == (<ArraySymbol>i0).array_ptr and
holds_alternative[cppSlice](self.ptr.indices()[0])):
# The indexed array must be a Constant square matrix
array = next(self.iter_predecessors())
if not isinstance(array, Constant):
return None
if array.ndim() != 2 or array.shape()[0] != array.shape()[1]:
return None

return Permutation(array, i0)
# The total operation must of the form A[i0, i1][i2, i3]

i0, i1 = self._iter_indices()
i2, i3 = index if isinstance(index, tuple) else (index, slice(None, None, None))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming there are only two indices here breaks the case of A[:, x][(x,)] and makes a somewhat more confusing error for A[:, x][x, :, :].


# It also must be of the form A[outer, inner][inner, outer]

if (
isinstance(i0, slice) and isinstance(i3, slice) # outer is a slice
and i0 == i3 and i0 == slice(None) # and those slices are empty
and isinstance(i1, ArraySymbol) and isinstance(i2, ArraySymbol) # inner is an array
and i1.id() == i2.id() # and those arrays are the same array
and i1.shape() == (array.shape()[0],) # and the shape of the array is correct
):
return Permutation(array, i1)

if (
isinstance(i1, slice) and isinstance(i2, slice) # inner is a slice
and i1 == i2 and i1 == slice(None) # and those slices are empty
and isinstance(i0, ArraySymbol) and isinstance(i3, ArraySymbol) # outer is an array
and i0.id() == i3.id() # and those arrays are the same array
and i0.shape() == (array.shape()[0],) # and the shape of the array is correct
):
return Permutation(array, i0)

return None

return super().__getitem__(index)

@classmethod
def _from_symbol(cls, Symbol symbol):
Expand Down Expand Up @@ -791,6 +802,17 @@ cdef class AdvancedIndexing(ArraySymbol):

zf.writestr(directory + "indices.json", encoder.encode(indices))

def _iter_indices(self):
for variant in self.ptr.indices():
if holds_alternative[cppSlice](variant):
cppslice = <cppSlice>get[cppSlice](variant)
yield slice(None)
elif holds_alternative[cppArrayNodePtr](variant):
array_ptr = <cppArrayNodePtr>get[cppArrayNodePtr](variant)
yield symbol_from_ptr(self.model, array_ptr)
else:
raise RuntimeError("unexpected variant contents")

cdef cppAdvancedIndexingNode* ptr

_register(AdvancedIndexing, typeid(cppAdvancedIndexingNode))
Expand Down Expand Up @@ -3474,10 +3496,7 @@ cdef class Permutation(ArraySymbol):
>>> type(p)
<class 'dwave.optimization.symbols.Permutation'>
"""
def __init__(self, Constant array, ListVariable x):
# todo: Loosen the types accepted. But this Cython code doesn't yet have
# the type heirarchy needed so for how we specify explicitly

def __init__(self, Constant array, ArraySymbol x):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main fix, but there were a few other edge cases as well

if array.model is not x.model:
raise ValueError("array and x do not share the same underlying model")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
fixes:
- |
Fix indexing operations sometimes raising an error when it should
create a ``Permutation`` symbol.
47 changes: 43 additions & 4 deletions tests/test_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -2708,7 +2708,19 @@ def generate_symbols(self):
model.lock()
yield p

def test(self):
def test_constant_integer(self):
from dwave.optimization.symbols import Permutation

model = Model()

A = model.constant(np.arange(25).reshape((5, 5)))
x = model.constant(np.arange(5))

self.assertIsInstance(A[x, :][:, x], Permutation)
self.assertIsInstance(A[:, x][x, :], Permutation)
self.assertIsInstance(A[:, x][x], Permutation)

def test_list_indexer(self):
from dwave.optimization.symbols import Permutation

model = Model()
Expand All @@ -2718,11 +2730,38 @@ def test(self):

self.assertIsInstance(A[x, :][:, x], Permutation)
self.assertIsInstance(A[:, x][x, :], Permutation)
self.assertIsInstance(A[:, x][x], Permutation)

def test_not_permutation(self):
# Some "near" permutations that aren't quite right
from dwave.optimization.symbols import Permutation

with self.subTest("A not square"):
model = Model()

A = model.constant(np.arange(30).reshape((5, 6)))
x = model.list(5)

self.assertNotIsInstance(A[x, :][:, x], Permutation)
self.assertNotIsInstance(A[:, x][x, :], Permutation)

with self.subTest("A not 2d"):
model = Model()

A = model.constant(np.arange(25).reshape((5, 5, 1)))
x = model.list(5)

self.assertNotIsInstance(A[x, :][:, x], Permutation)
self.assertNotIsInstance(A[:, x][x, :], Permutation)

with self.subTest("indexer wrong size"):
model = Model()

b = model.constant(np.arange(30).reshape((5, 6)))
A = model.constant(np.arange(25).reshape((5, 5)))
x = model.list(4)

self.assertNotIsInstance(b[x, :][:, x], Permutation)
self.assertNotIsInstance(b[:, x][x, :], Permutation)
self.assertNotIsInstance(A[x, :][:, x], Permutation)
self.assertNotIsInstance(A[:, x][x, :], Permutation)


class TestProd(utils.ReduceTests):
Expand Down