Skip to content

Commit 7a34314

Browse files
authored
Merge pull request #529 from Blosc/matmul
Refactor to allow for matmul ufunc
2 parents a719fdd + 0026ad8 commit 7a34314

File tree

9 files changed

+194
-171
lines changed

9 files changed

+194
-171
lines changed

src/blosc2/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ def _raise(exc):
451451
from .schunk import SChunk, open
452452
from . import linalg
453453
from .linalg import tensordot, vecdot, permute_dims, matrix_transpose, matmul, transpose, diagonal, outer
454-
from .shape_utils import linalg_funcs as linalg_funcs_list
454+
from .utils import linalg_funcs as linalg_funcs_list
455455
from . import fft
456456

457457
# Registry for postfilters

src/blosc2/lazyexpr.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,22 @@
4040
import blosc2
4141
from blosc2 import compute_chunks_blocks
4242
from blosc2.info import InfoReporter
43-
from blosc2.ndarray import (
43+
44+
from .proxy import _convert_dtype
45+
from .utils import (
4446
NUMPY_GE_2_0,
47+
constructors,
48+
elementwise_funcs,
4549
get_chunks_idx,
4650
get_intersecting_chunks,
51+
infer_shape,
52+
linalg_attrs,
53+
linalg_funcs,
54+
npvecdot,
4755
process_key,
56+
reducers,
4857
)
4958

50-
from .proxy import _convert_dtype
51-
from .shape_utils import constructors, elementwise_funcs, infer_shape, linalg_attrs, linalg_funcs, reducers
52-
5359
if not blosc2.IS_WASM:
5460
import numexpr
5561

@@ -78,7 +84,7 @@
7884
safe_numpy_globals["bitwise_invert"] = np.bitwise_not
7985
safe_numpy_globals["concat"] = np.concatenate
8086
safe_numpy_globals["matrix_transpose"] = np.transpose
81-
safe_numpy_globals["vecdot"] = blosc2.ndarray.npvecdot
87+
safe_numpy_globals["vecdot"] = npvecdot
8288

8389

8490
def ne_evaluate(expression, local_dict=None, **kwargs):

src/blosc2/linalg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
import numpy as np
1010

1111
import blosc2
12-
from blosc2.ndarray import get_intersecting_chunks, nptranspose, npvecdot, slice_to_chunktuple
12+
13+
from .utils import get_intersecting_chunks, nptranspose, npvecdot, slice_to_chunktuple
1314

1415
if TYPE_CHECKING:
1516
from collections.abc import Sequence

src/blosc2/ndarray.py

Lines changed: 12 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -27,31 +27,23 @@
2727

2828
import ndindex
2929
import numpy as np
30-
from ndindex.subindex_helpers import ceiling
3130

3231
import blosc2
3332
from blosc2 import SpecialValue, blosc2_ext, compute_chunks_blocks
3433
from blosc2.info import InfoReporter
3534
from blosc2.schunk import SChunk
3635

37-
# NumPy version and a convenient boolean flag
38-
NUMPY_GE_2_0 = np.__version__ >= "2.0"
39-
# handle different numpy versions
40-
if NUMPY_GE_2_0: # array-api compliant
41-
nplshift = np.bitwise_left_shift
42-
nprshift = np.bitwise_right_shift
43-
npbinvert = np.bitwise_invert
44-
npvecdot = np.vecdot
45-
nptranspose = np.permute_dims
46-
else: # not array-api compliant
47-
nplshift = np.left_shift
48-
nprshift = np.right_shift
49-
npbinvert = np.bitwise_not
50-
nptranspose = np.transpose
51-
52-
def npvecdot(a, b, axis=-1):
53-
return np.einsum("...i,...i->...", np.moveaxis(np.conj(a), axis, -1), np.moveaxis(b, axis, -1))
54-
36+
from .linalg import matmul
37+
from .utils import (
38+
_get_local_slice,
39+
_get_selection,
40+
get_chunks_idx,
41+
npbinvert,
42+
nplshift,
43+
nprshift,
44+
process_key,
45+
slice_to_chunktuple,
46+
)
5547

5648
# These functions in ufunc_map in ufunc_map_1param are implemented in numexpr and so we call
5749
# those instead (since numexpr uses multithreading it is faster)
@@ -179,15 +171,6 @@ def make_key_hashable(key):
179171
return key
180172

181173

182-
def process_key(key, shape):
183-
key = ndindex.ndindex(key).expand(shape).raw
184-
mask = tuple(
185-
isinstance(k, int) for k in key
186-
) # mask to track dummy dims introduced by int -> slice(k, k+1)
187-
key = tuple(slice(k, k + 1, None) if isinstance(k, int) else k for k in key) # key is slice, None, int
188-
return key, mask
189-
190-
191174
def get_ndarray_start_stop(ndim, key, shape):
192175
# key should be Nones and slices
193176
none_mask, start, stop, step = [], [], [], []
@@ -265,12 +248,6 @@ def check_contiguity(shape, part):
265248
return check_contiguity(shape, chunks)
266249

267250

268-
def get_chunks_idx(shape, chunks):
269-
chunks_idx = tuple(math.ceil(s / c) for s, c in zip(shape, chunks, strict=True))
270-
nchunks = math.prod(chunks_idx)
271-
return chunks_idx, nchunks
272-
273-
274251
def get_flat_slices_orig(shape: tuple[int], s: tuple[slice, ...]) -> list[slice]:
275252
"""
276253
From array with `shape`, get the flattened list of slices corresponding to `s`.
@@ -3074,6 +3051,7 @@ def chunkwise_logaddexp(inputs, output, offset):
30743051
np.logical_and: logical_and,
30753052
np.logical_or: logical_or,
30763053
np.logical_xor: logical_xor,
3054+
np.matmul: matmul,
30773055
}
30783056

30793057

@@ -6320,132 +6298,6 @@ def take_along_axis(x: blosc2.Array, indices: blosc2.Array, axis: int = -1) -> N
63206298
return blosc2.asarray(x[key])
63216299

63226300

6323-
class MyChunkRange:
6324-
def __init__(self, start, stop, step=1, n=1):
6325-
self.start = start
6326-
self.stop = stop
6327-
self.step = step
6328-
self.n = n
6329-
6330-
def __iter__(self):
6331-
for k in range(math.ceil((self.stop - self.start) / self.step)):
6332-
yield (self.start + k * self.step) // self.n
6333-
6334-
6335-
def slice_to_chunktuple(s, n):
6336-
# Adapted from _slice_iter in ndindex.ChunkSize.as_subchunks.
6337-
start, stop, step = s.start, s.stop, s.step
6338-
if step < 0:
6339-
temp = stop
6340-
stop = start + 1
6341-
start = temp + 1
6342-
step = -step # get positive steps
6343-
if step > n:
6344-
return MyChunkRange(start, stop, step, n)
6345-
else:
6346-
return range(start // n, ceiling(stop, n))
6347-
6348-
6349-
def _get_selection(ctuple, ptuple, chunks):
6350-
# we assume that at least one element of chunk intersects with the slice
6351-
# (as a consequence of only looping over intersecting chunks)
6352-
# ptuple is global slice, ctuple is chunk coords (in units of chunks)
6353-
pselection = ()
6354-
for i, s, csize in zip(ctuple, ptuple, chunks, strict=True):
6355-
# we need to advance to first element within chunk that intersects with slice, not
6356-
# necessarily the first element of chunk
6357-
# i * csize = s.start + n*step + k, already added n+1 elements, k in [1, step]
6358-
if s.step > 0:
6359-
np1 = (i * csize - s.start + s.step - 1) // s.step # gives (n + 1)
6360-
# can have n = -1 if s.start > i * csize, but never < -1 since have to intersect with chunk
6361-
pselection += (
6362-
slice(
6363-
builtins.max(
6364-
s.start, s.start + np1 * s.step
6365-
), # start+(n+1)*step gives i*csize if k=step
6366-
builtins.min(csize * (i + 1), s.stop),
6367-
s.step,
6368-
),
6369-
)
6370-
else:
6371-
# (i + 1) * csize = s.start + n*step + k, already added n+1 elements, k in [step+1, 0]
6372-
np1 = ((i + 1) * csize - s.start + s.step) // s.step # gives (n + 1)
6373-
# can have n = -1 if s.start < (i + 1) * csize, but never < -1 since have to intersect with chunk
6374-
pselection += (
6375-
slice(
6376-
builtins.min(s.start, s.start + np1 * s.step), # start+n*step gives (i+1)*csize if k=0
6377-
builtins.max(csize * i - 1, s.stop), # want to include csize * i
6378-
s.step,
6379-
),
6380-
)
6381-
6382-
# selection relative to coordinates of out (necessarily out_step = 1 as we work through out chunk-by-chunk of self)
6383-
# when added n + 1 elements
6384-
# ps.start = pt.start + step * (n+1) => n = (ps.start - pt.start - sign) // step
6385-
# hence, out_start = n + 1
6386-
# ps.stop = pt.start + step * (out_stop - 1) + k, k in [step, -1] or [1, step]
6387-
# => out_stop = (ps.stop - pt.start - sign) // step + 1
6388-
out_pselection = ()
6389-
i = 0
6390-
for ps, pt in zip(pselection, ptuple, strict=True):
6391-
sign_ = np.sign(pt.step)
6392-
n = (ps.start - pt.start - sign_) // pt.step
6393-
out_start = n + 1
6394-
# ps.stop always positive except for case where get full array (it is then -1 since desire 0th element)
6395-
out_stop = None if ps.stop == -1 else (ps.stop - pt.start - sign_) // pt.step + 1
6396-
out_pselection += (
6397-
slice(
6398-
out_start,
6399-
out_stop,
6400-
1,
6401-
),
6402-
)
6403-
i += 1
6404-
6405-
loc_selection = tuple( # is s.stop is None, get whole chunk so s.start - 0
6406-
slice(0, s.stop - s.start, s.step)
6407-
if s.step > 0
6408-
else slice(s.start if s.stop == -1 else s.start - s.stop, None, s.step)
6409-
for s in pselection
6410-
) # local coords of loaded part of chunk
6411-
6412-
return out_pselection, pselection, loc_selection
6413-
6414-
6415-
def _get_local_slice(prior_selection, post_selection, chunk_bounds):
6416-
chunk_begin, chunk_end = chunk_bounds
6417-
# +1 for negative steps as have to include start (exclude stop)
6418-
locbegin = np.hstack(
6419-
(
6420-
[s.start if s.step > 0 else s.stop + 1 for s in prior_selection],
6421-
chunk_begin,
6422-
[s.start if s.step > 0 else s.stop + 1 for s in post_selection],
6423-
),
6424-
casting="unsafe",
6425-
dtype="int64",
6426-
)
6427-
locend = np.hstack(
6428-
(
6429-
[s.stop if s.step > 0 else s.start + 1 for s in prior_selection],
6430-
chunk_end,
6431-
[s.stop if s.step > 0 else s.start + 1 for s in post_selection],
6432-
),
6433-
casting="unsafe",
6434-
dtype="int64",
6435-
)
6436-
return locbegin, locend
6437-
6438-
6439-
def get_intersecting_chunks(_slice, shape, chunks):
6440-
if 0 not in chunks:
6441-
chunk_size = ndindex.ChunkSize(chunks)
6442-
return chunk_size.as_subchunks(_slice, shape) # if _slice is (), returns all chunks
6443-
else:
6444-
return (
6445-
ndindex.ndindex(...).expand(shape),
6446-
) # chunk is whole array so just return full tuple to do loop once
6447-
6448-
64496301
def broadcast_to(arr: blosc2.Array, shape: tuple[int, ...]) -> NDArray:
64506302
"""
64516303
Broadcast an array to a new shape.

0 commit comments

Comments
 (0)