|
1 | 1 | import math
|
| 2 | +from operator import add |
2 | 3 | from typing import TYPE_CHECKING
|
3 | 4 |
|
4 | 5 | import ndindex
|
5 | 6 | import numpy as np
|
6 |
| -from toolz import map |
| 7 | +from toolz import accumulate, map |
7 | 8 |
|
8 |
| -from cubed.core.ops import general_blockwise |
| 9 | +from cubed.backend_array_api import backend_array_to_numpy_array |
| 10 | +from cubed.core.array import CoreArray |
| 11 | +from cubed.core.ops import general_blockwise, map_selection, merge_chunks |
| 12 | +from cubed.utils import array_size |
| 13 | +from cubed.vendor.dask.array.core import normalize_chunks |
9 | 14 |
|
10 | 15 | if TYPE_CHECKING:
|
11 | 16 | from cubed.array_api.array_object import Array
|
12 | 17 |
|
13 | 18 |
|
def index(x, key):
    """Subset an array, along one or more axes.

    Parameters
    ----------
    x
        The array to index. Its ``shape``, ``chunksize``, ``numblocks``,
        ``dtype`` and ``spec`` attributes are read.
    key
        A single indexer or a tuple of indexers. Integers, slices with a
        step >= 1, ``newaxis`` and at most one integer array are supported.
        Any ``CoreArray`` indexer is computed eagerly and converted to a
        NumPy array, since ndindex needs NumPy arrays.

    Returns
    -------
    The selected array. ``x`` itself is returned when the selection is a
    no-op and no new axes are requested.

    Raises
    ------
    NotImplementedError
        If a slice has a step < 1, if an indexer is not an integer, slice or
        integer array, or if more than one integer array is given.
    """
    if not isinstance(key, tuple):
        key = (key,)

    # Replace Cubed arrays with NumPy arrays - note that this may trigger a computation!
    # Note that NumPy arrays are needed for ndindex.
    key = tuple(
        backend_array_to_numpy_array(dim_sel.compute())
        if isinstance(dim_sel, CoreArray)
        else dim_sel
        for dim_sel in key
    )

    # Canonicalize index
    idx = ndindex.ndindex(key)
    idx = idx.expand(x.shape)

    # Remove newaxis values, to be filled in with expand_dims at end.
    # Integer indexers drop a dimension, so each newaxis position is shifted
    # left by the number of integer indexers that precede it.
    expanded_args = idx.args
    where_newaxis = [
        pos - sum(isinstance(ia, ndindex.Integer) for ia in expanded_args[:pos])
        for pos, ia in enumerate(expanded_args)
        if isinstance(ia, ndindex.Newaxis)
    ]
    idx = ndindex.Tuple(
        *(ia for ia in expanded_args if not isinstance(ia, ndindex.Newaxis))
    )
    selection = idx.raw

    # Check selection is supported
    if any(ia.step < 1 for ia in idx.args if isinstance(ia, ndindex.Slice)):
        raise NotImplementedError(f"Slice step must be >= 1: {key}")
    if not all(
        isinstance(ia, (ndindex.Integer, ndindex.Slice, ndindex.IntegerArray))
        for ia in idx.args
    ):
        raise NotImplementedError(
            "Only integer, slice, or integer array indexes are allowed."
        )
    if sum(1 for ia in idx.args if isinstance(ia, ndindex.IntegerArray)) > 1:
        raise NotImplementedError("Only one integer array index is allowed.")

    # Use ndindex to find the resulting array shape and chunks

    def chunk_len_for_indexer(ia, c):
        # output chunk length produced by one input chunk of length c
        if not isinstance(ia, ndindex.Slice):
            return c
        return max(c // ia.step, 1)

    def merged_chunk_len_for_indexer(ia, c):
        if not isinstance(ia, ndindex.Slice):
            return c
        if ia.step == 1:
            return c
        if (c // ia.step) < 1:
            return c
        # note that this may not be the same as c
        # but it is guaranteed to be a multiple of the corresponding
        # value returned by chunk_len_for_indexer, which is required
        # by merge_chunks
        return (c // ia.step) * ia.step

    shape = idx.newshape(x.shape)

    # An integer array with as many elements as the axis is long keeps the
    # shape but can reorder or repeat elements, so an unchanged shape only
    # means "no op" when the selection contains no integer array.
    has_integer_array = any(
        isinstance(ia, ndindex.IntegerArray) for ia in idx.args
    )

    if shape == x.shape and not has_integer_array:
        # no op case (except possibly newaxis applied below)
        out = x
    elif array_size(shape) == 0:
        # empty output case
        from cubed.array_api.creation_functions import empty

        out = empty(shape, dtype=x.dtype, chunks=x.chunksize, spec=x.spec)
    else:
        dtype = x.dtype
        chunks = tuple(
            chunk_len_for_indexer(ia, c)
            for ia, c in zip(idx.args, x.chunksize)
            if not isinstance(ia, ndindex.Integer)
        )

        # this is the same as chunks, except it has the same number of dimensions as the input
        out_chunksizes = tuple(
            chunk_len_for_indexer(ia, c) if not isinstance(ia, ndindex.Integer) else 1
            for ia, c in zip(idx.args, x.chunksize)
        )

        target_chunks = normalize_chunks(chunks, shape, dtype=dtype)

        # use map_selection (which uses general_blockwise) to allow more opportunities for optimization than map_direct

        def selection_function(out_key):
            # the first element of out_key is skipped; the rest are the
            # output chunk coordinates
            out_coords = out_key[1:]
            return _target_chunk_selection(target_chunks, out_coords, selection)

        max_num_input_blocks = _index_num_input_blocks(
            idx, x.chunksize, out_chunksizes, x.numblocks
        )

        out = map_selection(
            None,  # no function to apply after selection
            selection_function,
            x,
            shape,
            x.dtype,
            target_chunks,
            max_num_input_blocks=max_num_input_blocks,
        )

        # merge chunks for any dims with step > 1 so they are
        # the same size as the input (or slightly smaller due to rounding)
        merged_chunks = tuple(
            merged_chunk_len_for_indexer(ia, c)
            for ia, c in zip(idx.args, x.chunksize)
            if not isinstance(ia, ndindex.Integer)
        )
        if chunks != merged_chunks:
            out = merge_chunks(out, merged_chunks)

    if where_newaxis:
        # imported here rather than at module level, as in the original
        from cubed.array_api.manipulation_functions import expand_dims

        for axis in where_newaxis:
            out = expand_dims(out, axis=axis)

    return out
| 143 | + |
| 144 | + |
def _index_num_input_blocks(
    idx: ndindex.Tuple, in_chunksizes, out_chunksizes, numblocks
):
    """Estimate the maximum number of input blocks read by one output block.

    The four arguments are walked in lockstep, one entry per input dimension:
    the indexer, the input chunk size, the output chunk size and the number
    of input blocks. Each dimension contributes a factor, and the product of
    the factors is returned.

    Raises
    ------
    NotImplementedError
        If an indexer on a multi-block dimension is not an integer, slice or
        integer array.
    """
    num = 1
    for ia, c, oc, nb in zip(idx.args, in_chunksizes, out_chunksizes, numblocks):
        if isinstance(ia, ndindex.Integer) or nb == 1:
            pass  # single block
        elif isinstance(ia, ndindex.Slice):
            if (ia.start // c) == ((ia.stop - 1) // c):
                pass  # within same block
            elif ia.start % c != 0:
                num *= 2  # doesn't start on chunk boundary
            elif ia.step is not None and c % ia.step != 0 and oc > 1:
                # chunk size is not a multiple of step, and output chunks have more than one element
                # so some output chunks will access two input chunks
                num *= 2
        elif isinstance(ia, ndindex.IntegerArray):
            # in the worst case, elements could be retrieved from all blocks
            # TODO: improve to calculate the actual max input blocks
            num *= nb
        else:
            raise NotImplementedError(
                "Only integer, slice, or int array indexes are supported."
            )
    return num
| 170 | + |
| 171 | + |
| 172 | +def _target_chunk_selection(target_chunks, idx, selection): |
| 173 | + # integer, integer array, and slice indexes can be interspersed in selection |
| 174 | + # idx is the chunk index for the output (target_chunks) |
| 175 | + |
| 176 | + sel = [] |
| 177 | + i = 0 # index into target_chunks and idx |
| 178 | + for s in selection: |
| 179 | + if isinstance(s, slice): |
| 180 | + offset = s.start or 0 |
| 181 | + step = s.step if s.step is not None else 1 |
| 182 | + start = tuple( |
| 183 | + accumulate(add, tuple(x * step for x in target_chunks[i]), offset) |
| 184 | + ) |
| 185 | + j = idx[i] |
| 186 | + sel.append(slice(start[j], start[j + 1], step)) |
| 187 | + i += 1 |
| 188 | + # ndindex uses np.ndarray for integer arrays |
| 189 | + elif isinstance(s, np.ndarray): |
| 190 | + # find the cumulative chunk starts |
| 191 | + target_chunk_starts = [0] + list( |
| 192 | + accumulate(add, [c for c in target_chunks[i]]) |
| 193 | + ) |
| 194 | + # and use to slice the integer array |
| 195 | + j = idx[i] |
| 196 | + sel.append(s[target_chunk_starts[j] : target_chunk_starts[j + 1]]) |
| 197 | + i += 1 |
| 198 | + elif isinstance(s, int): |
| 199 | + sel.append(s) |
| 200 | + # don't increment i since integer indexes don't have a dimension in the target |
| 201 | + else: |
| 202 | + raise ValueError(f"Unsupported selection: {s}") |
| 203 | + return tuple(sel) |
| 204 | + |
| 205 | + |
14 | 206 | class BlockView:
|
15 | 207 | """An array-like interface to the blocks of an array."""
|
16 | 208 |
|
|
0 commit comments