51
51
ufunc_map_1param ,
52
52
)
53
53
54
+ from .shape_utils import constructors , infer_shape , reducers
55
+
54
56
if not blosc2 .IS_WASM :
55
57
import numexpr
56
58
@@ -127,12 +129,6 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
127
129
"V" : np .bytes_ ,
128
130
}
129
131
130
- # All the available constructors and reducers necessary for the (string) expression evaluator
131
- constructors = ("arange" , "linspace" , "fromiter" , "zeros" , "ones" , "empty" , "full" , "frombuffer" )
132
- # Note that, as reshape is accepted as a method too, it should always come last in the list
133
- constructors += ("reshape" ,)
134
- reducers = ("sum" , "prod" , "min" , "max" , "std" , "mean" , "var" , "any" , "all" , "slice" )
135
-
136
132
functions = [
137
133
"sin" ,
138
134
"cos" ,
@@ -658,6 +654,85 @@ def validate_expr(expr: str) -> None:
658
654
raise ValueError (f"Invalid method name: { method } " )
659
655
660
656
657
def extract_and_replace_slices(expr, operands):
    """
    Replace every ``obj.slice(...)`` method call in ``expr`` with a temporary ``oN`` variable.

    Chained calls (``a.slice(...).slice(...)``) are handled innermost-first: the inner
    call is replaced by a temporary, and the outer call is then applied to that temporary.
    For each replaced call a placeholder array of ones, with one element per inferred
    dimension, is stored under the temporary name.

    Parameters
    ----------
    expr: str
        The expression to rewrite; it must parse with ``ast.parse(expr, mode="eval")``.
    operands: dict
        Mapping of variable names to operands. It is copied, never modified.

    Returns
    -------
    new_expr: str
        The expression string with the slice calls replaced by ``oN`` names.
    new_ops: dict
        Copy of ``operands`` extended with one placeholder array per replaced call.

    Notes
    -----
    When the sliced object is not a plain name found in ``operands`` (e.g.
    ``(a + b).slice(...)``), no dtype can be taken from a base operand, so a NumPy
    placeholder with the default dtype is created instead of failing with a ``KeyError``.
    """
    import logging  # local import: only needed to report shape-inference failures

    # Shapes known so far; anything without a ``shape`` attribute is treated as a scalar
    shapes = {name: getattr(value, "shape", ()) for name, value in operands.items()}
    new_ops = operands.copy()  # never touch the caller's dictionary

    class SliceCollector(ast.NodeTransformer):
        def visit_Call(self, node):
            # Visit children first, so inner links of a slice chain are already
            # replaced by temporaries when the outer call is inspected
            self.generic_visit(node)

            # Only method calls of the form obj.slice(...) are rewritten
            if not (isinstance(node.func, ast.Attribute) and node.func.attr == "slice"):
                return node

            # Resolve the operand being sliced, if it is a known plain name
            target = node.func.value
            base_op = new_ops.get(target.id) if isinstance(target, ast.Name) else None

            # The call (with inner slices already replaced) as source text
            full_expr = ast.unparse(node)

            # Pick a temporary name that cannot clobber an existing operand
            index = len(new_ops)
            while f"o{index}" in new_ops:
                index += 1
            new_var = f"o{index}"

            # Best-effort shape inference; fall back to a scalar shape on failure
            try:
                shape = infer_shape(full_expr, shapes)
            except Exception as e:
                logging.getLogger(__name__).warning("Shape inference failed for %s: %s", full_expr, e)
                shape = ()

            # The placeholder inherits the dtype of the sliced operand when there is one
            dtype = getattr(base_op, "dtype", None)

            # Keep the placeholder in the same array family as the sliced operand
            if base_op is not None and isinstance(base_op, blosc2.NDArray):
                new_op = blosc2.ones((1,) * len(shape), dtype=dtype)
            else:
                new_op = np.ones((1,) * len(shape), dtype=dtype)

            new_ops[new_var] = new_op
            shapes[new_var] = shape  # lets an outer slice in a chain infer from this one

            # Replace the call node with a reference to the temporary
            return ast.Name(id=new_var, ctx=ast.Load())

    # Parse, rewrite and convert back to an expression string
    tree = ast.parse(expr, mode="eval")
    new_tree = SliceCollector().visit(tree)
    ast.fix_missing_locations(new_tree)
    new_expr = ast.unparse(new_tree)

    return new_expr, new_ops
734
+
735
+
661
736
def get_expr_operands (expression : str ) -> set :
662
737
"""
663
738
Given an expression in string form, return its operands.
@@ -2174,7 +2249,7 @@ def fuse_expressions(expr, new_base, dup_op):
2174
2249
return new_expr
2175
2250
2176
2251
2177
- def infer_dtype (op , value1 , value2 ):
2252
+ def check_dtype (op , value1 , value2 ):
2178
2253
if op == "contains" :
2179
2254
return np .dtype (np .bool_ )
2180
2255
@@ -2262,7 +2337,7 @@ def __init__(self, new_op): # noqa: C901
2262
2337
self .operands = {}
2263
2338
return
2264
2339
value1 , op , value2 = new_op
2265
- dtype_ = infer_dtype (op , value1 , value2 ) # perform some checks
2340
+ dtype_ = check_dtype (op , value1 , value2 ) # perform some checks
2266
2341
if value2 is None :
2267
2342
if isinstance (value1 , LazyExpr ):
2268
2343
self .expression = f"{ op } ({ value1 .expression } )"
@@ -2443,25 +2518,7 @@ def dtype(self):
2443
2518
if any (v is None for v in self .operands .values ()):
2444
2519
return None
2445
2520
2446
- operands = {
2447
- key : np .ones (np .ones (len (value .shape ), dtype = int ), dtype = value .dtype )
2448
- if hasattr (value , "shape" )
2449
- else value
2450
- for key , value in self .operands .items ()
2451
- }
2452
-
2453
- if "contains" in self .expression :
2454
- _out = ne_evaluate (self .expression , local_dict = operands )
2455
- else :
2456
- # Create a globals dict with the functions of numpy
2457
- globals_dict = {f : getattr (np , f ) for f in functions if f not in ("contains" , "pow" )}
2458
- try :
2459
- _out = eval (self .expression , globals_dict , operands )
2460
- except RuntimeWarning :
2461
- # Sometimes, numpy gets a RuntimeWarning when evaluating expressions
2462
- # with synthetic operands (1's). Let's try with numexpr, which is not so picky
2463
- # about this.
2464
- _out = ne_evaluate (self .expression , local_dict = operands )
2521
+ _out = _numpy_eval_expr (self .expression , self .operands , prefer_blosc = False )
2465
2522
self ._dtype_ = _out .dtype
2466
2523
self ._expression_ = self .expression
2467
2524
return self ._dtype_
@@ -2501,9 +2558,13 @@ def shape(self):
2501
2558
def chunks (self ):
2502
2559
if hasattr (self , "_chunks" ):
2503
2560
return self ._chunks
2504
- self . _shape , self ._chunks , self ._blocks , fast_path = validate_inputs (
2561
+ shape , self ._chunks , self ._blocks , fast_path = validate_inputs (
2505
2562
self .operands , getattr (self , "_out" , None )
2506
2563
)
2564
+ if not hasattr (self , "_shape" ):
2565
+ self ._shape = shape
2566
+ if self ._shape != shape : # validate inputs only works for elementwise funcs so returned shape might
2567
+ fast_path = False # be incompatible with true output shape
2507
2568
if not fast_path :
2508
2569
# Not using the fast path, so we need to compute the chunks/blocks automatically
2509
2570
self ._chunks , self ._blocks = compute_chunks_blocks (self .shape , None , None , dtype = self .dtype )
@@ -2513,9 +2574,13 @@ def chunks(self):
2513
2574
def blocks (self ):
2514
2575
if hasattr (self , "_blocks" ):
2515
2576
return self ._blocks
2516
- self . _shape , self ._chunks , self ._blocks , fast_path = validate_inputs (
2577
+ shape , self ._chunks , self ._blocks , fast_path = validate_inputs (
2517
2578
self .operands , getattr (self , "_out" , None )
2518
2579
)
2580
+ if not hasattr (self , "_shape" ):
2581
+ self ._shape = shape
2582
+ if self ._shape != shape : # validate inputs only works for elementwise funcs so returned shape might
2583
+ fast_path = False # be incompatible with true output shape
2519
2584
if not fast_path :
2520
2585
# Not using the fast path, so we need to compute the chunks/blocks automatically
2521
2586
self ._chunks , self ._blocks = compute_chunks_blocks (self .shape , None , None , dtype = self .dtype )
@@ -3105,15 +3170,19 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
3105
3170
# Most in particular, castings like np.int8 et al. can be very useful to allow
3106
3171
# for desired data types in the output.
3107
3172
_operands = operands | local_vars
3108
- _globals = get_expr_globals (expression )
3109
- _globals |= dtype_symbols
3110
3173
# Check that operands are proper Operands, LazyArray or scalars; if not, convert to NDArray objects
3111
3174
for op , val in _operands .items ():
3112
3175
if not (isinstance (val , (blosc2 .Operand , blosc2 .LazyArray , np .ndarray )) or np .isscalar (val )):
3113
3176
_operands [op ] = blosc2 .SimpleProxy (val )
3114
- new_expr = eval (_expression , _globals , _operands )
3177
+ # for scalars just return value (internally converts to () if necessary)
3178
+ opshapes = {k : v if not hasattr (v , "shape" ) else v .shape for k , v in _operands .items ()}
3179
+ _shape = infer_shape (_expression , opshapes ) # infer shape, includes constructors
3180
+ # substitutes with numpy operands (cheap for reductions) and
3181
+ # defaults to blosc2 functions (cheap for constructors)
3182
+ # have to handle slices since a[10] on a dummy variable of shape (1,1) doesn't work
3183
+ desliced_expr , desliced_ops = extract_and_replace_slices (_expression , _operands )
3184
+ new_expr = _numpy_eval_expr (desliced_expr , desliced_ops , prefer_blosc = True )
3115
3185
_dtype = new_expr .dtype
3116
- _shape = new_expr .shape
3117
3186
if isinstance (new_expr , blosc2 .LazyExpr ):
3118
3187
# DO NOT restore the original expression and operands
3119
3188
# Instead rebase operands and restore only constructors
@@ -3139,15 +3208,16 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
3139
3208
new_expr .operands = operands_
3140
3209
new_expr .operands_tosave = operands
3141
3210
elif isinstance (new_expr , blosc2 .NDArray ) and len (operands ) == 1 :
3142
- # passed either "a" or possible "a[:10]"
3211
+ # passed "a", "a[:10]", 'sum(a)'
3143
3212
expression_ , operands_ = conserve_functions (
3144
3213
_expression , _operands , {"o0" : list (operands .values ())[0 ]} | local_vars
3145
3214
)
3146
3215
new_expr = cls (None )
3147
3216
new_expr .expression = expression_
3148
3217
new_expr .operands = operands_
3149
3218
else :
3150
- # An immediate evaluation happened (e.g. all operands are numpy arrays)
3219
+ # An immediate evaluation happened
3220
+ # (e.g. all operands are numpy arrays or constructors)
3151
3221
new_expr = cls (None )
3152
3222
new_expr .expression = expression
3153
3223
new_expr .operands = operands
@@ -3348,6 +3418,42 @@ def save(self, **kwargs):
3348
3418
raise NotImplementedError ("For safety reasons, this is not implemented for UDFs" )
3349
3419
3350
3420
3421
def _numpy_eval_expr(expression, operands, prefer_blosc=False):
    """
    Evaluate ``expression`` on cheap stand-ins for ``operands`` and return the result.

    Array-like operands are swapped for arrays of ones holding a single element per
    dimension; every other operand is passed through unchanged.

    Parameters
    ----------
    expression: str
        The expression to evaluate.
    operands: dict
        Mapping of variable names to operands.
    prefer_blosc: bool
        If True, operands exposing ``chunks`` are replaced by ``blosc2.ones`` arrays and
        names are resolved through ``get_expr_globals(expression)`` plus ``dtype_symbols``.
        If False, operands exposing ``shape`` are replaced by ``np.ones`` arrays and
        names are resolved against the NumPy versions of ``functions``.

    Returns
    -------
    out
        Whatever the evaluation produces.
    """
    stand_ins = {}
    if prefer_blosc:
        for name, operand in operands.items():
            if hasattr(operand, "chunks"):
                stand_ins[name] = blosc2.ones((1,) * len(operand.shape), dtype=operand.dtype)
            else:
                stand_ins[name] = operand
    else:
        for name, operand in operands.items():
            if hasattr(operand, "shape"):
                unit_shape = np.ones(len(operand.shape), dtype=int)
                stand_ins[name] = np.ones(unit_shape, dtype=operand.dtype)
            else:
                stand_ins[name] = operand

    # Expressions mentioning "contains" are always handed to numexpr
    if "contains" in expression:
        return ne_evaluate(expression, local_dict=stand_ins)

    if prefer_blosc:
        # Namespace built by get_expr_globals, extended with the dtype symbols
        namespace = get_expr_globals(expression)
        namespace |= dtype_symbols
    else:
        # NumPy implementations of the supported functions
        namespace = {f: getattr(np, f) for f in functions if f not in ("contains", "pow")}

    # NOTE(review): eval() runs arbitrary code; confirm that expressions reaching this
    # point have been validated by the callers.
    try:
        return eval(expression, namespace, stand_ins)
    except RuntimeWarning:
        # Sometimes, numpy gets a RuntimeWarning when evaluating expressions with
        # synthetic operands (1's). Retry with numexpr, which is not so picky about this.
        return ne_evaluate(expression, local_dict=stand_ins)
3455
+
3456
+
3351
3457
def lazyudf (
3352
3458
func : Callable [[tuple , np .ndarray , tuple [int ]], None ],
3353
3459
inputs : Sequence [Any ] | None ,
0 commit comments