Skip to content

Commit 73cce92

Browse files
authored
Move index to indexing.py module (#744)
1 parent 3e4c2fb commit 73cce92

File tree

3 files changed

+197
-197
lines changed

3 files changed

+197
-197
lines changed

cubed/core/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def visualize(
222222
)
223223

224224
def __getitem__(self: T_ChunkedArray, key, /) -> T_ChunkedArray:
    """Return the subset of this array selected by ``key``.

    The work is delegated to ``cubed.core.indexing.index``.
    """
    # Imported at call time rather than at module level: cubed.core.indexing
    # imports CoreArray from this module, so a top-level import here would be
    # circular.
    from cubed.core.indexing import index

    return index(self, key)
228228

cubed/core/indexing.py

Lines changed: 194 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,208 @@
11
import math
2+
from operator import add
23
from typing import TYPE_CHECKING
34

45
import ndindex
56
import numpy as np
6-
from toolz import map
7+
from toolz import accumulate, map
78

8-
from cubed.core.ops import general_blockwise
9+
from cubed.backend_array_api import backend_array_to_numpy_array
10+
from cubed.core.array import CoreArray
11+
from cubed.core.ops import general_blockwise, map_selection, merge_chunks
12+
from cubed.utils import array_size
13+
from cubed.vendor.dask.array.core import normalize_chunks
914

1015
if TYPE_CHECKING:
1116
from cubed.array_api.array_object import Array
1217

1318

19+
def index(x, key):
    """Subset an array, along one or more axes.

    Parameters
    ----------
    x
        The array to index. Its ``shape``, ``chunksize``, ``numblocks``,
        ``dtype`` and ``spec`` attributes are read.
    key
        A single indexer or a tuple of indexers. Any indexer that is a Cubed
        array (``CoreArray``) is computed eagerly and converted to a NumPy
        array before the index is canonicalized with ndindex.

    Returns
    -------
    The selected array. ``x`` itself is returned when the selection keeps
    every element in place (apart from any new axes, which are added with
    ``expand_dims``).

    Raises
    ------
    NotImplementedError
        If a slice has a step smaller than 1, if an indexer is not an integer,
        slice, or integer array, or if more than one integer array is given.
    """
    if not isinstance(key, tuple):
        key = (key,)

    # Replace Cubed arrays with NumPy arrays - note that this may trigger a computation!
    # Note that NumPy arrays are needed for ndindex.
    key = tuple(
        backend_array_to_numpy_array(dim_sel.compute())
        if isinstance(dim_sel, CoreArray)
        else dim_sel
        for dim_sel in key
    )

    # Canonicalize index
    idx = ndindex.ndindex(key)
    idx = idx.expand(x.shape)

    # Remove newaxis values, to be filled in with expand_dims at end
    where_newaxis = [
        i for i, ia in enumerate(idx.args) if isinstance(ia, ndindex.Newaxis)
    ]
    for i, a in enumerate(where_newaxis):
        # integer indexes before a newaxis drop a dimension, so the position
        # of the new axis in the output moves left by that many places
        n = sum(isinstance(ia, ndindex.Integer) for ia in idx.args[:a])
        if n:
            where_newaxis[i] -= n
    idx = ndindex.Tuple(*(ia for ia in idx.args if not isinstance(ia, ndindex.Newaxis)))
    selection = idx.raw

    # Check selection is supported
    if any(ia.step < 1 for ia in idx.args if isinstance(ia, ndindex.Slice)):
        raise NotImplementedError(f"Slice step must be >= 1: {key}")
    if not all(
        isinstance(ia, (ndindex.Integer, ndindex.Slice, ndindex.IntegerArray))
        for ia in idx.args
    ):
        raise NotImplementedError(
            "Only integer, slice, or integer array indexes are allowed."
        )
    if sum(1 for ia in idx.args if isinstance(ia, ndindex.IntegerArray)) > 1:
        raise NotImplementedError("Only one integer array index is allowed.")

    # Use ndindex to find the resulting array shape and chunks

    def chunk_len_for_indexer(ia, c):
        # chunk length produced directly by the selection for this axis
        if not isinstance(ia, ndindex.Slice):
            return c
        return max(c // ia.step, 1)

    def merged_chunk_len_for_indexer(ia, c):
        # chunk length wanted after merging the chunks of a stepped slice
        if not isinstance(ia, ndindex.Slice):
            return c
        if ia.step == 1:
            return c
        if (c // ia.step) < 1:
            return c
        # note that this may not be the same as c
        # but it is guaranteed to be a multiple of the corresponding
        # value returned by chunk_len_for_indexer, which is required
        # by merge_chunks
        return (c // ia.step) * ia.step

    shape = idx.newshape(x.shape)

    # An integer array can reorder or repeat positions while leaving the shape
    # unchanged (e.g. [2, 1, 0] on an axis of length 3), so an unchanged shape
    # only proves the selection is a no-op when no integer array is involved.
    has_integer_array = any(isinstance(ia, ndindex.IntegerArray) for ia in idx.args)

    if shape == x.shape and not has_integer_array:
        # no op case (except possibly newaxis applied below)
        out = x
    elif array_size(shape) == 0:
        # empty output case
        from cubed.array_api.creation_functions import empty

        out = empty(shape, dtype=x.dtype, chunks=x.chunksize, spec=x.spec)
    else:
        dtype = x.dtype
        chunks = tuple(
            chunk_len_for_indexer(ia, c)
            for ia, c in zip(idx.args, x.chunksize)
            if not isinstance(ia, ndindex.Integer)
        )

        # this is the same as chunks, except it has the same number of dimensions as the input
        out_chunksizes = tuple(
            chunk_len_for_indexer(ia, c) if not isinstance(ia, ndindex.Integer) else 1
            for ia, c in zip(idx.args, x.chunksize)
        )

        target_chunks = normalize_chunks(chunks, shape, dtype=dtype)

        # use map_selection (which uses general_blockwise) to allow more opportunities for optimization than map_direct

        def selection_function(out_key):
            # the first element of out_key is skipped; the rest are the
            # block coordinates of the output chunk
            out_coords = out_key[1:]
            return _target_chunk_selection(target_chunks, out_coords, selection)

        max_num_input_blocks = _index_num_input_blocks(
            idx, x.chunksize, out_chunksizes, x.numblocks
        )

        out = map_selection(
            None,  # no function to apply after selection
            selection_function,
            x,
            shape,
            x.dtype,
            target_chunks,
            max_num_input_blocks=max_num_input_blocks,
        )

        # merge chunks for any dims with step > 1 so they are
        # the same size as the input (or slightly smaller due to rounding)
        merged_chunks = tuple(
            merged_chunk_len_for_indexer(ia, c)
            for ia, c in zip(idx.args, x.chunksize)
            if not isinstance(ia, ndindex.Integer)
        )
        if chunks != merged_chunks:
            out = merge_chunks(out, merged_chunks)

    if where_newaxis:
        # imported lazily, and only once rather than on every loop iteration
        from cubed.array_api.manipulation_functions import expand_dims

        for axis in where_newaxis:
            out = expand_dims(out, axis=axis)

    return out
143+
144+
145+
def _index_num_input_blocks(
    idx: ndindex.Tuple, in_chunksizes, out_chunksizes, numblocks
):
    """Return the input-block count that ``index`` passes as ``max_num_input_blocks``.

    The per-axis factors are multiplied together:

    - integer indexes, and axes with a single block, contribute 1;
    - a slice contributes 1 when its first and last positions fall in the
      same input block, otherwise 2 if it either does not start on a chunk
      boundary or has a step that does not divide the chunk size while the
      output chunks hold more than one element;
    - an integer array contributes the number of blocks along that axis.

    Raises ``NotImplementedError`` for any other kind of indexer on an axis
    that has more than one block.
    """
    total = 1
    per_axis = zip(idx.args, in_chunksizes, out_chunksizes, numblocks)
    for indexer, in_len, out_len, nblocks in per_axis:
        if nblocks == 1 or isinstance(indexer, ndindex.Integer):
            # single block
            continue

        if isinstance(indexer, ndindex.Slice):
            first_block = indexer.start // in_len
            last_block = (indexer.stop - 1) // in_len
            if first_block == last_block:
                # within same block
                continue
            if indexer.start % in_len != 0 or (
                indexer.step is not None
                and in_len % indexer.step != 0
                and out_len > 1
            ):
                # either the slice doesn't start on a chunk boundary, or the
                # chunk size is not a multiple of the step and output chunks
                # have more than one element, so some output chunks will
                # access two input chunks
                total *= 2
            continue

        if isinstance(indexer, ndindex.IntegerArray):
            # in the worst case, elements could be retrieved from all blocks
            # TODO: improve to calculate the actual max input blocks
            total *= nblocks
            continue

        raise NotImplementedError(
            "Only integer, slice, or int array indexes are supported."
        )
    return total
170+
171+
172+
def _target_chunk_selection(target_chunks, idx, selection):
173+
# integer, integer array, and slice indexes can be interspersed in selection
174+
# idx is the chunk index for the output (target_chunks)
175+
176+
sel = []
177+
i = 0 # index into target_chunks and idx
178+
for s in selection:
179+
if isinstance(s, slice):
180+
offset = s.start or 0
181+
step = s.step if s.step is not None else 1
182+
start = tuple(
183+
accumulate(add, tuple(x * step for x in target_chunks[i]), offset)
184+
)
185+
j = idx[i]
186+
sel.append(slice(start[j], start[j + 1], step))
187+
i += 1
188+
# ndindex uses np.ndarray for integer arrays
189+
elif isinstance(s, np.ndarray):
190+
# find the cumulative chunk starts
191+
target_chunk_starts = [0] + list(
192+
accumulate(add, [c for c in target_chunks[i]])
193+
)
194+
# and use to slice the integer array
195+
j = idx[i]
196+
sel.append(s[target_chunk_starts[j] : target_chunk_starts[j + 1]])
197+
i += 1
198+
elif isinstance(s, int):
199+
sel.append(s)
200+
# don't increment i since integer indexes don't have a dimension in the target
201+
else:
202+
raise ValueError(f"Unsupported selection: {s}")
203+
return tuple(sel)
204+
205+
14206
class BlockView:
15207
"""An array-like interface to the blocks of an array."""
16208

0 commit comments

Comments
 (0)