41
41
from blosc2 import compute_chunks_blocks
42
42
from blosc2 .info import InfoReporter
43
43
from blosc2 .ndarray import (
44
+ NUMPY_GE_2_0 ,
44
45
_check_allowed_dtypes ,
45
46
get_chunks_idx ,
46
47
get_intersecting_chunks ,
53
54
54
55
from .shape_utils import constructors , infer_shape , lin_alg_funcs , reducers
55
56
56
- not_numexpr_funcs = lin_alg_funcs + ("clip" , "logaddexp" )
57
-
58
57
if not blosc2 .IS_WASM :
59
58
import numexpr
60
59
60
+ global safe_blosc2_globals
61
+ safe_blosc2_globals = {}
62
+ global safe_numpy_globals
63
+ # Use numpy eval when running in WebAssembly
64
+ safe_numpy_globals = {"np" : np }
65
+ # Add all first-level numpy functions
66
+ safe_numpy_globals .update (
67
+ {name : getattr (np , name ) for name in dir (np ) if callable (getattr (np , name )) and not name .startswith ("_" )}
68
+ )
69
+
70
+ if not NUMPY_GE_2_0 : # handle non-array-api compliance
71
+ safe_numpy_globals ["acos" ] = np .arccos
72
+ safe_numpy_globals ["acosh" ] = np .arccosh
73
+ safe_numpy_globals ["asin" ] = np .arcsin
74
+ safe_numpy_globals ["asinh" ] = np .arcsinh
75
+ safe_numpy_globals ["atan" ] = np .arctan
76
+ safe_numpy_globals ["atanh" ] = np .arctanh
77
+ safe_numpy_globals ["atan2" ] = np .arctan2
78
+ safe_numpy_globals ["permute_dims" ] = np .transpose
79
+ safe_numpy_globals ["pow" ] = np .power
80
+ safe_numpy_globals ["bitwise_left_shift" ] = np .left_shift
81
+ safe_numpy_globals ["bitwise_right_shift" ] = np .right_shift
82
+ safe_numpy_globals ["bitwise_invert" ] = np .bitwise_not
83
+ safe_numpy_globals ["concat" ] = np .concatenate
84
+ safe_numpy_globals ["matrix_transpose" ] = np .transpose
85
+ safe_numpy_globals ["vecdot" ] = blosc2 .ndarray .npvecdot
86
+
61
87
62
88
def ne_evaluate (expression , local_dict = None , ** kwargs ):
63
89
"""Safely evaluate expressions using numexpr when possible, falling back to numpy."""
@@ -76,22 +102,24 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
76
102
)
77
103
}
78
104
if blosc2 .IS_WASM :
79
- # Use numpy eval when running in WebAssembly
80
- safe_globals = {"np" : np }
81
- # Add all first-level numpy functions
82
- safe_globals .update (
83
- {
84
- name : getattr (np , name )
85
- for name in dir (np )
86
- if callable (getattr (np , name )) and not name .startswith ("_" )
87
- }
88
- )
105
+ global safe_numpy_globals
106
+ if "out" in kwargs :
107
+ out = kwargs .pop ("out" )
108
+ out [:] = eval (expression , safe_numpy_globals , local_dict )
109
+ return out
110
+ return eval (expression , safe_numpy_globals , local_dict )
111
+ try :
112
+ return numexpr .evaluate (expression , local_dict = local_dict , ** kwargs )
113
+ except ValueError as e :
114
+ raise e # unsafe expression
115
+ except Exception : # non_numexpr functions present
116
+ global safe_blosc2_globals
117
+ res = eval (expression , safe_blosc2_globals , local_dict )
89
118
if "out" in kwargs :
90
119
out = kwargs .pop ("out" )
91
- out [:] = eval ( expression , safe_globals , local_dict )
120
+ out [:] = res [()] if isinstance ( res , blosc2 . LazyArray ) else res
92
121
return out
93
- return eval (expression , safe_globals , local_dict )
94
- return numexpr .evaluate (expression , local_dict = local_dict , ** kwargs )
122
+ return res [()] if isinstance (res , blosc2 .LazyArray ) else res
95
123
96
124
97
125
# Define empty ndindex tuple for function defaults
@@ -131,56 +159,116 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
131
159
"V" : np .bytes_ ,
132
160
}
133
161
134
- functions = [
135
- "sin" ,
136
- "cos" ,
137
- "tan" ,
138
- "sqrt" ,
139
- "sinh" ,
140
- "cosh" ,
141
- "tanh" ,
142
- "arcsin" ,
162
+ blosc2_funcs = [
163
+ "abs" ,
164
+ "acos" ,
165
+ "acosh" ,
166
+ "add" ,
167
+ "all" ,
168
+ "any" ,
169
+ "arange" ,
143
170
"arccos" ,
171
+ "arccosh" ,
172
+ "arcsin" ,
173
+ "arcsinh" ,
144
174
"arctan" ,
145
175
"arctan2" ,
146
- "arcsinh" ,
147
- "arccosh" ,
148
176
"arctanh" ,
149
- "exp" ,
177
+ "asin" ,
178
+ "asinh" ,
179
+ "atan" ,
180
+ "atan2" ,
181
+ "atanh" ,
182
+ "bitwise_and" ,
183
+ "bitwise_invert" ,
184
+ "bitwise_left_shift" ,
185
+ "bitwise_or" ,
186
+ "bitwise_right_shift" ,
187
+ "bitwise_xor" ,
188
+ "broadcast_to" ,
189
+ "ceil" ,
190
+ "clip" ,
191
+ "concat" ,
192
+ "concatenate" ,
193
+ "copy" ,
194
+ "copysign" ,
195
+ "count_nonzero" ,
196
+ "divide" ,
197
+ "empty" ,
198
+ "empty_like" ,
199
+ "equal" ,
200
+ "expand_dims" ,
150
201
"expm1" ,
202
+ "eye" ,
203
+ "floor" ,
204
+ "floor_divide" ,
205
+ "frombuffer" ,
206
+ "fromiter" ,
207
+ "full" ,
208
+ "full_like" ,
209
+ "greater" ,
210
+ "greater_equal" ,
211
+ "hypot" ,
212
+ "isfinite" ,
213
+ "isinf" ,
214
+ "isnan" ,
215
+ "less_equal" ,
216
+ "less_than" ,
217
+ "linspace" ,
151
218
"log" ,
152
- "log10" ,
153
219
"log1p" ,
154
220
"log2" ,
155
- "conj" ,
156
- "real" ,
157
- "imag" ,
158
- "contains" ,
159
- "abs" ,
160
- "sum" ,
161
- "prod" ,
162
- "mean" ,
163
- "std" ,
164
- "var" ,
165
- "min" ,
221
+ "log10" ,
222
+ "logaddexp" ,
223
+ "logical_and" ,
224
+ "logical_not" ,
225
+ "logical_or" ,
226
+ "logical_xor" ,
227
+ "matmul" ,
228
+ "matrix_transpose" ,
166
229
"max" ,
167
- "any" ,
168
- "all" ,
169
- "pow" if np .__version__ .startswith ("2." ) else "power" ,
170
- "where" ,
171
- "isnan" ,
172
- "isfinite" ,
173
- "isinf" ,
174
- "nextafter" ,
175
- "copysign" ,
176
- "hypot" ,
177
230
"maximum" ,
231
+ "mean" ,
232
+ "meshgrid" ,
233
+ "min" ,
178
234
"minimum" ,
179
- "floor" ,
180
- "ceil" ,
181
- "trunc" ,
182
- "signbit" ,
235
+ "multiply" ,
236
+ "nans" ,
237
+ "ndarray_from_cframe" ,
238
+ "negative" ,
239
+ "nextafter" ,
240
+ "not_equal" ,
241
+ "ones" ,
242
+ "ones_like" ,
243
+ "permute_dims" ,
244
+ "positive" ,
245
+ "pow" ,
246
+ "prod" ,
247
+ "real" ,
248
+ "reciprocal" ,
249
+ "remainder" ,
250
+ "reshape" ,
183
251
"round" ,
252
+ "sign" ,
253
+ "signbit" ,
254
+ "sort" ,
255
+ "square" ,
256
+ "squeeze" ,
257
+ "stack" ,
258
+ "sum" ,
259
+ "subtract" ,
260
+ "take" ,
261
+ "take_along_axis" ,
262
+ "tan" ,
263
+ "tanh" ,
264
+ "tensordot" ,
265
+ "transpose" ,
266
+ "trunc" ,
267
+ "var" ,
268
+ "vecdot" ,
269
+ "where" ,
270
+ "zeros" ,
271
+ "zeros_like" ,
184
272
]
185
273
186
274
# Gather all callable functions in numpy
@@ -192,10 +280,8 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
192
280
numpy_ufuncs = {name for name , member in inspect .getmembers (np , lambda x : isinstance (x , np .ufunc ))}
193
281
# Add these functions to the list of available functions
194
282
# (will be evaluated via the array interface)
195
- additional_funcs = sorted ((numpy_funcs | numpy_ufuncs ) - set (functions ))
196
- functions += additional_funcs
197
-
198
- functions += constructors
283
+ additional_funcs = sorted ((numpy_funcs | numpy_ufuncs ) - set (blosc2_funcs ))
284
+ functions = blosc2_funcs + additional_funcs
199
285
200
286
relational_ops = ["==" , "!=" , "<" , "<=" , ">" , ">=" ]
201
287
logical_ops = ["&" , "|" , "^" , "~" ]
@@ -204,8 +290,7 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
204
290
205
291
def get_expr_globals (expression ):
206
292
"""Build a dictionary of functions needed for evaluating the expression."""
207
- _globals = {}
208
-
293
+ _globals = {"np" : np , "blosc2" : blosc2 }
209
294
# Only check for functions that actually appear in the expression
210
295
# This avoids many unnecessary string searches
211
296
for func in functions :
@@ -2922,8 +3007,23 @@ def find_args(expr):
2922
3007
2923
3008
return value , expression [idx :idx2 ]
2924
3009
2925
- def _compute_expr (self , item , kwargs ):
2926
- if any (method in self .expression for method in reducers + not_numexpr_funcs ):
3010
+ def _compute_expr (self , item , kwargs ): # noqa : C901
3011
+ # ne_evaluate will need safe_blosc2_globals for some functions (e.g. clip, logaddexp)
3012
+ # that are implemenetd in python-blosc2 not in numexpr
3013
+ global safe_blosc2_globals
3014
+ if len (safe_blosc2_globals ) == 0 :
3015
+ # First eval call, fill blosc2_safe_globals for ne_evaluate
3016
+ safe_blosc2_globals = {"blosc2" : blosc2 }
3017
+ # Add all first-level blosc2 functions
3018
+ safe_blosc2_globals .update (
3019
+ {
3020
+ name : getattr (blosc2 , name )
3021
+ for name in dir (blosc2 )
3022
+ if callable (getattr (blosc2 , name )) and not name .startswith ("_" )
3023
+ }
3024
+ )
3025
+
3026
+ if any (method in self .expression for method in reducers + lin_alg_funcs ):
2927
3027
# We have reductions in the expression (probably coming from a string lazyexpr)
2928
3028
# Also includes slice
2929
3029
_globals = get_expr_globals (self .expression )
@@ -3450,7 +3550,7 @@ def _numpy_eval_expr(expression, operands, prefer_blosc=False):
3450
3550
_globals = get_expr_globals (expression )
3451
3551
_globals |= dtype_symbols
3452
3552
else :
3453
- _globals = { f : getattr ( np , f ) for f in functions if f not in ( "contains" , "pow" )}
3553
+ _globals = safe_numpy_globals
3454
3554
try :
3455
3555
_out = eval (expression , _globals , ops )
3456
3556
except RuntimeWarning :
0 commit comments