|
27 | 27 |
|
28 | 28 | import ndindex |
29 | 29 | import numpy as np |
30 | | -from ndindex.subindex_helpers import ceiling |
31 | 30 |
|
32 | 31 | import blosc2 |
33 | 32 | from blosc2 import SpecialValue, blosc2_ext, compute_chunks_blocks |
34 | 33 | from blosc2.info import InfoReporter |
35 | 34 | from blosc2.schunk import SChunk |
36 | 35 |
|
37 | | -# NumPy version and a convenient boolean flag |
38 | | -NUMPY_GE_2_0 = np.__version__ >= "2.0" |
39 | | -# handle different numpy versions |
40 | | -if NUMPY_GE_2_0: # array-api compliant |
41 | | - nplshift = np.bitwise_left_shift |
42 | | - nprshift = np.bitwise_right_shift |
43 | | - npbinvert = np.bitwise_invert |
44 | | - npvecdot = np.vecdot |
45 | | - nptranspose = np.permute_dims |
46 | | -else: # not array-api compliant |
47 | | - nplshift = np.left_shift |
48 | | - nprshift = np.right_shift |
49 | | - npbinvert = np.bitwise_not |
50 | | - nptranspose = np.transpose |
51 | | - |
52 | | - def npvecdot(a, b, axis=-1): |
53 | | - return np.einsum("...i,...i->...", np.moveaxis(np.conj(a), axis, -1), np.moveaxis(b, axis, -1)) |
54 | | - |
| 36 | +from .linalg import matmul |
| 37 | +from .utils import ( |
| 38 | + _get_local_slice, |
| 39 | + _get_selection, |
| 40 | + get_chunks_idx, |
| 41 | + npbinvert, |
| 42 | + nplshift, |
| 43 | + nprshift, |
| 44 | + process_key, |
| 45 | + slice_to_chunktuple, |
| 46 | +) |
55 | 47 |
|
56 | 48 | # These functions in ufunc_map in ufunc_map_1param are implemented in numexpr and so we call |
57 | 49 | # those instead (since numexpr uses multithreading it is faster) |
@@ -179,15 +171,6 @@ def make_key_hashable(key): |
179 | 171 | return key |
180 | 172 |
|
181 | 173 |
|
182 | | -def process_key(key, shape): |
183 | | - key = ndindex.ndindex(key).expand(shape).raw |
184 | | - mask = tuple( |
185 | | - isinstance(k, int) for k in key |
186 | | - ) # mask to track dummy dims introduced by int -> slice(k, k+1) |
187 | | - key = tuple(slice(k, k + 1, None) if isinstance(k, int) else k for k in key) # key is slice, None, int |
188 | | - return key, mask |
189 | | - |
190 | | - |
191 | 174 | def get_ndarray_start_stop(ndim, key, shape): |
192 | 175 | # key should be Nones and slices |
193 | 176 | none_mask, start, stop, step = [], [], [], [] |
@@ -265,12 +248,6 @@ def check_contiguity(shape, part): |
265 | 248 | return check_contiguity(shape, chunks) |
266 | 249 |
|
267 | 250 |
|
268 | | -def get_chunks_idx(shape, chunks): |
269 | | - chunks_idx = tuple(math.ceil(s / c) for s, c in zip(shape, chunks, strict=True)) |
270 | | - nchunks = math.prod(chunks_idx) |
271 | | - return chunks_idx, nchunks |
272 | | - |
273 | | - |
274 | 251 | def get_flat_slices_orig(shape: tuple[int], s: tuple[slice, ...]) -> list[slice]: |
275 | 252 | """ |
276 | 253 | From array with `shape`, get the flattened list of slices corresponding to `s`. |
@@ -3074,6 +3051,7 @@ def chunkwise_logaddexp(inputs, output, offset): |
3074 | 3051 | np.logical_and: logical_and, |
3075 | 3052 | np.logical_or: logical_or, |
3076 | 3053 | np.logical_xor: logical_xor, |
| 3054 | + np.matmul: matmul, |
3077 | 3055 | } |
3078 | 3056 |
|
3079 | 3057 |
|
@@ -6320,132 +6298,6 @@ def take_along_axis(x: blosc2.Array, indices: blosc2.Array, axis: int = -1) -> N |
6320 | 6298 | return blosc2.asarray(x[key]) |
6321 | 6299 |
|
6322 | 6300 |
|
6323 | | -class MyChunkRange: |
6324 | | - def __init__(self, start, stop, step=1, n=1): |
6325 | | - self.start = start |
6326 | | - self.stop = stop |
6327 | | - self.step = step |
6328 | | - self.n = n |
6329 | | - |
6330 | | - def __iter__(self): |
6331 | | - for k in range(math.ceil((self.stop - self.start) / self.step)): |
6332 | | - yield (self.start + k * self.step) // self.n |
6333 | | - |
6334 | | - |
6335 | | -def slice_to_chunktuple(s, n): |
6336 | | - # Adapted from _slice_iter in ndindex.ChunkSize.as_subchunks. |
6337 | | - start, stop, step = s.start, s.stop, s.step |
6338 | | - if step < 0: |
6339 | | - temp = stop |
6340 | | - stop = start + 1 |
6341 | | - start = temp + 1 |
6342 | | - step = -step # get positive steps |
6343 | | - if step > n: |
6344 | | - return MyChunkRange(start, stop, step, n) |
6345 | | - else: |
6346 | | - return range(start // n, ceiling(stop, n)) |
6347 | | - |
6348 | | - |
6349 | | -def _get_selection(ctuple, ptuple, chunks): |
6350 | | - # we assume that at least one element of chunk intersects with the slice |
6351 | | - # (as a consequence of only looping over intersecting chunks) |
6352 | | - # ptuple is global slice, ctuple is chunk coords (in units of chunks) |
6353 | | - pselection = () |
6354 | | - for i, s, csize in zip(ctuple, ptuple, chunks, strict=True): |
6355 | | - # we need to advance to first element within chunk that intersects with slice, not |
6356 | | - # necessarily the first element of chunk |
6357 | | - # i * csize = s.start + n*step + k, already added n+1 elements, k in [1, step] |
6358 | | - if s.step > 0: |
6359 | | - np1 = (i * csize - s.start + s.step - 1) // s.step # gives (n + 1) |
6360 | | - # can have n = -1 if s.start > i * csize, but never < -1 since have to intersect with chunk |
6361 | | - pselection += ( |
6362 | | - slice( |
6363 | | - builtins.max( |
6364 | | - s.start, s.start + np1 * s.step |
6365 | | - ), # start+(n+1)*step gives i*csize if k=step |
6366 | | - builtins.min(csize * (i + 1), s.stop), |
6367 | | - s.step, |
6368 | | - ), |
6369 | | - ) |
6370 | | - else: |
6371 | | - # (i + 1) * csize = s.start + n*step + k, already added n+1 elements, k in [step+1, 0] |
6372 | | - np1 = ((i + 1) * csize - s.start + s.step) // s.step # gives (n + 1) |
6373 | | - # can have n = -1 if s.start < (i + 1) * csize, but never < -1 since have to intersect with chunk |
6374 | | - pselection += ( |
6375 | | - slice( |
6376 | | - builtins.min(s.start, s.start + np1 * s.step), # start+n*step gives (i+1)*csize if k=0 |
6377 | | - builtins.max(csize * i - 1, s.stop), # want to include csize * i |
6378 | | - s.step, |
6379 | | - ), |
6380 | | - ) |
6381 | | - |
6382 | | - # selection relative to coordinates of out (necessarily out_step = 1 as we work through out chunk-by-chunk of self) |
6383 | | - # when added n + 1 elements |
6384 | | - # ps.start = pt.start + step * (n+1) => n = (ps.start - pt.start - sign) // step |
6385 | | - # hence, out_start = n + 1 |
6386 | | - # ps.stop = pt.start + step * (out_stop - 1) + k, k in [step, -1] or [1, step] |
6387 | | - # => out_stop = (ps.stop - pt.start - sign) // step + 1 |
6388 | | - out_pselection = () |
6389 | | - i = 0 |
6390 | | - for ps, pt in zip(pselection, ptuple, strict=True): |
6391 | | - sign_ = np.sign(pt.step) |
6392 | | - n = (ps.start - pt.start - sign_) // pt.step |
6393 | | - out_start = n + 1 |
6394 | | - # ps.stop always positive except for case where get full array (it is then -1 since desire 0th element) |
6395 | | - out_stop = None if ps.stop == -1 else (ps.stop - pt.start - sign_) // pt.step + 1 |
6396 | | - out_pselection += ( |
6397 | | - slice( |
6398 | | - out_start, |
6399 | | - out_stop, |
6400 | | - 1, |
6401 | | - ), |
6402 | | - ) |
6403 | | - i += 1 |
6404 | | - |
6405 | | - loc_selection = tuple( # is s.stop is None, get whole chunk so s.start - 0 |
6406 | | - slice(0, s.stop - s.start, s.step) |
6407 | | - if s.step > 0 |
6408 | | - else slice(s.start if s.stop == -1 else s.start - s.stop, None, s.step) |
6409 | | - for s in pselection |
6410 | | - ) # local coords of loaded part of chunk |
6411 | | - |
6412 | | - return out_pselection, pselection, loc_selection |
6413 | | - |
6414 | | - |
6415 | | -def _get_local_slice(prior_selection, post_selection, chunk_bounds): |
6416 | | - chunk_begin, chunk_end = chunk_bounds |
6417 | | - # +1 for negative steps as have to include start (exclude stop) |
6418 | | - locbegin = np.hstack( |
6419 | | - ( |
6420 | | - [s.start if s.step > 0 else s.stop + 1 for s in prior_selection], |
6421 | | - chunk_begin, |
6422 | | - [s.start if s.step > 0 else s.stop + 1 for s in post_selection], |
6423 | | - ), |
6424 | | - casting="unsafe", |
6425 | | - dtype="int64", |
6426 | | - ) |
6427 | | - locend = np.hstack( |
6428 | | - ( |
6429 | | - [s.stop if s.step > 0 else s.start + 1 for s in prior_selection], |
6430 | | - chunk_end, |
6431 | | - [s.stop if s.step > 0 else s.start + 1 for s in post_selection], |
6432 | | - ), |
6433 | | - casting="unsafe", |
6434 | | - dtype="int64", |
6435 | | - ) |
6436 | | - return locbegin, locend |
6437 | | - |
6438 | | - |
6439 | | -def get_intersecting_chunks(_slice, shape, chunks): |
6440 | | - if 0 not in chunks: |
6441 | | - chunk_size = ndindex.ChunkSize(chunks) |
6442 | | - return chunk_size.as_subchunks(_slice, shape) # if _slice is (), returns all chunks |
6443 | | - else: |
6444 | | - return ( |
6445 | | - ndindex.ndindex(...).expand(shape), |
6446 | | - ) # chunk is whole array so just return full tuple to do loop once |
6447 | | - |
6448 | | - |
6449 | 6301 | def broadcast_to(arr: blosc2.Array, shape: tuple[int, ...]) -> NDArray: |
6450 | 6302 | """ |
6451 | 6303 | Broadcast an array to a new shape. |
|
0 commit comments