Skip to content

Commit 78cb864

Browse files
committed
Fixed compatibility with numpy 1.26
1 parent 38c6bd3 commit 78cb864

File tree

7 files changed

+236
-104
lines changed

7 files changed

+236
-104
lines changed

src/blosc2/lazyexpr.py

Lines changed: 162 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from blosc2 import compute_chunks_blocks
4242
from blosc2.info import InfoReporter
4343
from blosc2.ndarray import (
44+
NUMPY_GE_2_0,
4445
_check_allowed_dtypes,
4546
get_chunks_idx,
4647
get_intersecting_chunks,
@@ -53,11 +54,36 @@
5354

5455
from .shape_utils import constructors, infer_shape, lin_alg_funcs, reducers
5556

56-
not_numexpr_funcs = lin_alg_funcs + ("clip", "logaddexp")
57-
5857
if not blosc2.IS_WASM:
5958
import numexpr
6059

60+
global safe_blosc2_globals
61+
safe_blosc2_globals = {}
62+
global safe_numpy_globals
63+
# Use numpy eval when running in WebAssembly
64+
safe_numpy_globals = {"np": np}
65+
# Add all first-level numpy functions
66+
safe_numpy_globals.update(
67+
{name: getattr(np, name) for name in dir(np) if callable(getattr(np, name)) and not name.startswith("_")}
68+
)
69+
70+
if not NUMPY_GE_2_0: # handle non-array-api compliance
71+
safe_numpy_globals["acos"] = np.arccos
72+
safe_numpy_globals["acosh"] = np.arccosh
73+
safe_numpy_globals["asin"] = np.arcsin
74+
safe_numpy_globals["asinh"] = np.arcsinh
75+
safe_numpy_globals["atan"] = np.arctan
76+
safe_numpy_globals["atanh"] = np.arctanh
77+
safe_numpy_globals["atan2"] = np.arctan2
78+
safe_numpy_globals["permute_dims"] = np.transpose
79+
safe_numpy_globals["pow"] = np.power
80+
safe_numpy_globals["bitwise_left_shift"] = np.left_shift
81+
safe_numpy_globals["bitwise_right_shift"] = np.right_shift
82+
safe_numpy_globals["bitwise_invert"] = np.bitwise_not
83+
safe_numpy_globals["concat"] = np.concatenate
84+
safe_numpy_globals["matrix_transpose"] = np.transpose
85+
safe_numpy_globals["vecdot"] = blosc2.ndarray.npvecdot
86+
6187

6288
def ne_evaluate(expression, local_dict=None, **kwargs):
6389
"""Safely evaluate expressions using numexpr when possible, falling back to numpy."""
@@ -76,22 +102,24 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
76102
)
77103
}
78104
if blosc2.IS_WASM:
79-
# Use numpy eval when running in WebAssembly
80-
safe_globals = {"np": np}
81-
# Add all first-level numpy functions
82-
safe_globals.update(
83-
{
84-
name: getattr(np, name)
85-
for name in dir(np)
86-
if callable(getattr(np, name)) and not name.startswith("_")
87-
}
88-
)
105+
global safe_numpy_globals
106+
if "out" in kwargs:
107+
out = kwargs.pop("out")
108+
out[:] = eval(expression, safe_numpy_globals, local_dict)
109+
return out
110+
return eval(expression, safe_numpy_globals, local_dict)
111+
try:
112+
return numexpr.evaluate(expression, local_dict=local_dict, **kwargs)
113+
except ValueError as e:
114+
raise e # unsafe expression
115+
except Exception: # non_numexpr functions present
116+
global safe_blosc2_globals
117+
res = eval(expression, safe_blosc2_globals, local_dict)
89118
if "out" in kwargs:
90119
out = kwargs.pop("out")
91-
out[:] = eval(expression, safe_globals, local_dict)
120+
out[:] = res[()] if isinstance(res, blosc2.LazyArray) else res
92121
return out
93-
return eval(expression, safe_globals, local_dict)
94-
return numexpr.evaluate(expression, local_dict=local_dict, **kwargs)
122+
return res[()] if isinstance(res, blosc2.LazyArray) else res
95123

96124

97125
# Define empty ndindex tuple for function defaults
@@ -131,56 +159,116 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
131159
"V": np.bytes_,
132160
}
133161

134-
functions = [
135-
"sin",
136-
"cos",
137-
"tan",
138-
"sqrt",
139-
"sinh",
140-
"cosh",
141-
"tanh",
142-
"arcsin",
162+
blosc2_funcs = [
163+
"abs",
164+
"acos",
165+
"acosh",
166+
"add",
167+
"all",
168+
"any",
169+
"arange",
143170
"arccos",
171+
"arccosh",
172+
"arcsin",
173+
"arcsinh",
144174
"arctan",
145175
"arctan2",
146-
"arcsinh",
147-
"arccosh",
148176
"arctanh",
149-
"exp",
177+
"asin",
178+
"asinh",
179+
"atan",
180+
"atan2",
181+
"atanh",
182+
"bitwise_and",
183+
"bitwise_invert",
184+
"bitwise_left_shift",
185+
"bitwise_or",
186+
"bitwise_right_shift",
187+
"bitwise_xor",
188+
"broadcast_to",
189+
"ceil",
190+
"clip",
191+
"concat",
192+
"concatenate",
193+
"copy",
194+
"copysign",
195+
"count_nonzero",
196+
"divide",
197+
"empty",
198+
"empty_like",
199+
"equal",
200+
"expand_dims",
150201
"expm1",
202+
"eye",
203+
"floor",
204+
"floor_divide",
205+
"frombuffer",
206+
"fromiter",
207+
"full",
208+
"full_like",
209+
"greater",
210+
"greater_equal",
211+
"hypot",
212+
"isfinite",
213+
"isinf",
214+
"isnan",
215+
"less_equal",
216+
"less_than",
217+
"linspace",
151218
"log",
152-
"log10",
153219
"log1p",
154220
"log2",
155-
"conj",
156-
"real",
157-
"imag",
158-
"contains",
159-
"abs",
160-
"sum",
161-
"prod",
162-
"mean",
163-
"std",
164-
"var",
165-
"min",
221+
"log10",
222+
"logaddexp",
223+
"logical_and",
224+
"logical_not",
225+
"logical_or",
226+
"logical_xor",
227+
"matmul",
228+
"matrix_transpose",
166229
"max",
167-
"any",
168-
"all",
169-
"pow" if np.__version__.startswith("2.") else "power",
170-
"where",
171-
"isnan",
172-
"isfinite",
173-
"isinf",
174-
"nextafter",
175-
"copysign",
176-
"hypot",
177230
"maximum",
231+
"mean",
232+
"meshgrid",
233+
"min",
178234
"minimum",
179-
"floor",
180-
"ceil",
181-
"trunc",
182-
"signbit",
235+
"multiply",
236+
"nans",
237+
"ndarray_from_cframe",
238+
"negative",
239+
"nextafter",
240+
"not_equal",
241+
"ones",
242+
"ones_like",
243+
"permute_dims",
244+
"positive",
245+
"pow",
246+
"prod",
247+
"real",
248+
"reciprocal",
249+
"remainder",
250+
"reshape",
183251
"round",
252+
"sign",
253+
"signbit",
254+
"sort",
255+
"square",
256+
"squeeze",
257+
"stack",
258+
"sum",
259+
"subtract",
260+
"take",
261+
"take_along_axis",
262+
"tan",
263+
"tanh",
264+
"tensordot",
265+
"transpose",
266+
"trunc",
267+
"var",
268+
"vecdot",
269+
"where",
270+
"zeros",
271+
"zeros_like",
184272
]
185273

186274
# Gather all callable functions in numpy
@@ -192,10 +280,8 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
192280
numpy_ufuncs = {name for name, member in inspect.getmembers(np, lambda x: isinstance(x, np.ufunc))}
193281
# Add these functions to the list of available functions
194282
# (will be evaluated via the array interface)
195-
additional_funcs = sorted((numpy_funcs | numpy_ufuncs) - set(functions))
196-
functions += additional_funcs
197-
198-
functions += constructors
283+
additional_funcs = sorted((numpy_funcs | numpy_ufuncs) - set(blosc2_funcs))
284+
functions = blosc2_funcs + additional_funcs
199285

200286
relational_ops = ["==", "!=", "<", "<=", ">", ">="]
201287
logical_ops = ["&", "|", "^", "~"]
@@ -204,8 +290,7 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
204290

205291
def get_expr_globals(expression):
206292
"""Build a dictionary of functions needed for evaluating the expression."""
207-
_globals = {}
208-
293+
_globals = {"np": np, "blosc2": blosc2}
209294
# Only check for functions that actually appear in the expression
210295
# This avoids many unnecessary string searches
211296
for func in functions:
@@ -2922,8 +3007,23 @@ def find_args(expr):
29223007

29233008
return value, expression[idx:idx2]
29243009

2925-
def _compute_expr(self, item, kwargs):
2926-
if any(method in self.expression for method in reducers + not_numexpr_funcs):
3010+
def _compute_expr(self, item, kwargs): # noqa : C901
3011+
# ne_evaluate will need safe_blosc2_globals for some functions (e.g. clip, logaddexp)
3012+
# that are implemented in python-blosc2, not in numexpr
3013+
global safe_blosc2_globals
3014+
if len(safe_blosc2_globals) == 0:
3015+
# First eval call, fill safe_blosc2_globals for ne_evaluate
3016+
safe_blosc2_globals = {"blosc2": blosc2}
3017+
# Add all first-level blosc2 functions
3018+
safe_blosc2_globals.update(
3019+
{
3020+
name: getattr(blosc2, name)
3021+
for name in dir(blosc2)
3022+
if callable(getattr(blosc2, name)) and not name.startswith("_")
3023+
}
3024+
)
3025+
3026+
if any(method in self.expression for method in reducers + lin_alg_funcs):
29273027
# We have reductions in the expression (probably coming from a string lazyexpr)
29283028
# Also includes slice
29293029
_globals = get_expr_globals(self.expression)
@@ -3450,7 +3550,7 @@ def _numpy_eval_expr(expression, operands, prefer_blosc=False):
34503550
_globals = get_expr_globals(expression)
34513551
_globals |= dtype_symbols
34523552
else:
3453-
_globals = {f: getattr(np, f) for f in functions if f not in ("contains", "pow")}
3553+
_globals = safe_numpy_globals
34543554
try:
34553555
_out = eval(expression, _globals, ops)
34563556
except RuntimeWarning:

src/blosc2/linalg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010

1111
import blosc2
12-
from blosc2.ndarray import get_intersecting_chunks, slice_to_chunktuple
12+
from blosc2.ndarray import get_intersecting_chunks, npvecdot, slice_to_chunktuple
1313

1414
if TYPE_CHECKING:
1515
from collections.abc import Sequence
@@ -340,7 +340,7 @@ def vecdot(x1: blosc2.NDArray, x2: blosc2.NDArray, axis: int = -1, **kwargs) ->
340340
fast_path = kwargs.pop("fast_path", None) # for testing purposes
341341
# Added this to pass array-api tests (which use internal getitem to check results)
342342
if isinstance(x1, np.ndarray) and isinstance(x2, np.ndarray):
343-
return np.vecdot(x1, x2, axis=axis)
343+
return npvecdot(x1, x2, axis=axis)
344344

345345
x1, x2 = blosc2.asarray(x1), blosc2.asarray(x2)
346346

@@ -399,15 +399,15 @@ def vecdot(x1: blosc2.NDArray, x2: blosc2.NDArray, axis: int = -1, **kwargs) ->
399399
if fast_path: # just load everything, also handles case of 0 in shapes
400400
bx1 = x1[a_selection]
401401
bx2 = x2[b_selection]
402-
result[res_chunk] += np.vecdot(bx1, bx2, axis=axis) # handles conjugation of bx1
402+
result[res_chunk] += npvecdot(bx1, bx2, axis=axis) # handles conjugation of bx1
403403
else: # operands too big, have to go chunk-by-chunk
404404
for ochunk in range(0, a_shape_red, a_chunks_red):
405405
op_chunk = (slice(ochunk, builtins.min(ochunk + a_chunks_red, x1.shape[a_axes]), 1),)
406406
a_selection = a_selection[:a_axes] + op_chunk + a_selection[a_axes + 1 :]
407407
b_selection = b_selection[:b_axes] + op_chunk + b_selection[b_axes + 1 :]
408408
bx1 = x1[a_selection]
409409
bx2 = x2[b_selection]
410-
res = np.vecdot(bx1, bx2, axis=axis) # handles conjugation of bx1
410+
res = npvecdot(bx1, bx2, axis=axis) # handles conjugation of bx1
411411
result[res_chunk] += res
412412
return result
413413

0 commit comments

Comments
 (0)