Skip to content

Commit 73aa182

Browse files
committed
Fixed double evaluations for reductions
1 parent b4137ee commit 73aa182

File tree

2 files changed

+149
-39
lines changed

2 files changed

+149
-39
lines changed

src/blosc2/lazyexpr.py

Lines changed: 141 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
ufunc_map_1param,
5252
)
5353

54+
from .shape_utils import constructors, infer_shape, reducers
55+
5456
if not blosc2.IS_WASM:
5557
import numexpr
5658

@@ -127,12 +129,6 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
127129
"V": np.bytes_,
128130
}
129131

130-
# All the available constructors and reducers necessary for the (string) expression evaluator
131-
constructors = ("arange", "linspace", "fromiter", "zeros", "ones", "empty", "full", "frombuffer")
132-
# Note that, as reshape is accepted as a method too, it should always come last in the list
133-
constructors += ("reshape",)
134-
reducers = ("sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "slice")
135-
136132
functions = [
137133
"sin",
138134
"cos",
@@ -658,6 +654,85 @@ def validate_expr(expr: str) -> None:
658654
raise ValueError(f"Invalid method name: {method}")
659655

660656

657+
def extract_and_replace_slices(expr, operands):
658+
"""
659+
Replaces all var.slice(...).slice(...) chains in expr with oN temporary variables.
660+
Infers shapes using infer_shape and creates placeholder arrays in new_ops.
661+
662+
Returns:
663+
new_expr: expression string with oN replacements
664+
new_ops: dictionary mapping variable names (original and oN) to arrays
665+
"""
666+
# Copy shapes and operands
667+
shapes = {k: () if not hasattr(v, "shape") else v.shape for k, v in operands.items()}
668+
new_ops = operands.copy() # copy dictionary
669+
670+
# Parse the expression
671+
tree = ast.parse(expr, mode="eval")
672+
673+
# Mapping of AST nodes to new variable names
674+
replacements = {}
675+
676+
class SliceCollector(ast.NodeTransformer):
677+
def visit_Call(self, node):
678+
# Recursively visit children first
679+
self.generic_visit(node)
680+
681+
# Detect method calls: obj.slice(...)
682+
if isinstance(node.func, ast.Attribute) and node.func.attr == "slice":
683+
obj = node.func.value
684+
685+
# If the object is already replaced, keep the replacement
686+
base_name = None
687+
if isinstance(obj, ast.Name):
688+
base_name = obj.id
689+
elif isinstance(obj, ast.Call) and obj in replacements:
690+
base_name = replacements[obj]["base_var"]
691+
692+
# Build the full slice chain expression as a string
693+
full_expr = ast.unparse(node)
694+
695+
# Create a new temporary variable
696+
new_var = f"o{len(new_ops)}"
697+
698+
# Infer shape
699+
try:
700+
shape = infer_shape(full_expr, shapes)
701+
except Exception as e:
702+
print(f"⚠️ Shape inference failed for {full_expr}: {e}")
703+
shape = ()
704+
705+
# Determine dtype
706+
dtype = new_ops[base_name].dtype if base_name else None
707+
708+
# Create placeholder array
709+
if isinstance(new_ops[base_name], blosc2.NDArray):
710+
new_op = blosc2.ones((1,) * len(shape), dtype=dtype)
711+
else:
712+
new_op = np.ones((1,) * len(shape), dtype=dtype)
713+
714+
new_ops[new_var] = new_op
715+
shapes[new_var] = shape
716+
717+
# Record replacement
718+
replacements[node] = {"new_var": new_var, "base_var": base_name}
719+
720+
# Replace the AST node with the new variable
721+
return ast.Name(id=new_var, ctx=ast.Load())
722+
723+
return node
724+
725+
# Transform the AST
726+
transformer = SliceCollector()
727+
new_tree = transformer.visit(tree)
728+
ast.fix_missing_locations(new_tree)
729+
730+
# Convert back to expression string
731+
new_expr = ast.unparse(new_tree)
732+
733+
return new_expr, new_ops
734+
735+
661736
def get_expr_operands(expression: str) -> set:
662737
"""
663738
Given an expression in string form, return its operands.
@@ -2174,7 +2249,7 @@ def fuse_expressions(expr, new_base, dup_op):
21742249
return new_expr
21752250

21762251

2177-
def infer_dtype(op, value1, value2):
2252+
def check_dtype(op, value1, value2):
21782253
if op == "contains":
21792254
return np.dtype(np.bool_)
21802255

@@ -2262,7 +2337,7 @@ def __init__(self, new_op): # noqa: C901
22622337
self.operands = {}
22632338
return
22642339
value1, op, value2 = new_op
2265-
dtype_ = infer_dtype(op, value1, value2) # perform some checks
2340+
dtype_ = check_dtype(op, value1, value2) # perform some checks
22662341
if value2 is None:
22672342
if isinstance(value1, LazyExpr):
22682343
self.expression = f"{op}({value1.expression})"
@@ -2443,25 +2518,7 @@ def dtype(self):
24432518
if any(v is None for v in self.operands.values()):
24442519
return None
24452520

2446-
operands = {
2447-
key: np.ones(np.ones(len(value.shape), dtype=int), dtype=value.dtype)
2448-
if hasattr(value, "shape")
2449-
else value
2450-
for key, value in self.operands.items()
2451-
}
2452-
2453-
if "contains" in self.expression:
2454-
_out = ne_evaluate(self.expression, local_dict=operands)
2455-
else:
2456-
# Create a globals dict with the functions of numpy
2457-
globals_dict = {f: getattr(np, f) for f in functions if f not in ("contains", "pow")}
2458-
try:
2459-
_out = eval(self.expression, globals_dict, operands)
2460-
except RuntimeWarning:
2461-
# Sometimes, numpy gets a RuntimeWarning when evaluating expressions
2462-
# with synthetic operands (1's). Let's try with numexpr, which is not so picky
2463-
# about this.
2464-
_out = ne_evaluate(self.expression, local_dict=operands)
2521+
_out = _numpy_eval_expr(self.expression, self.operands, prefer_blosc=False)
24652522
self._dtype_ = _out.dtype
24662523
self._expression_ = self.expression
24672524
return self._dtype_
@@ -2501,9 +2558,13 @@ def shape(self):
25012558
def chunks(self):
25022559
if hasattr(self, "_chunks"):
25032560
return self._chunks
2504-
self._shape, self._chunks, self._blocks, fast_path = validate_inputs(
2561+
shape, self._chunks, self._blocks, fast_path = validate_inputs(
25052562
self.operands, getattr(self, "_out", None)
25062563
)
2564+
if not hasattr(self, "_shape"):
2565+
self._shape = shape
2566+
if self._shape != shape: # validate inputs only works for elementwise funcs so returned shape might
2567+
fast_path = False # be incompatible with true output shape
25072568
if not fast_path:
25082569
# Not using the fast path, so we need to compute the chunks/blocks automatically
25092570
self._chunks, self._blocks = compute_chunks_blocks(self.shape, None, None, dtype=self.dtype)
@@ -2513,9 +2574,13 @@ def chunks(self):
25132574
def blocks(self):
25142575
if hasattr(self, "_blocks"):
25152576
return self._blocks
2516-
self._shape, self._chunks, self._blocks, fast_path = validate_inputs(
2577+
shape, self._chunks, self._blocks, fast_path = validate_inputs(
25172578
self.operands, getattr(self, "_out", None)
25182579
)
2580+
if not hasattr(self, "_shape"):
2581+
self._shape = shape
2582+
if self._shape != shape: # validate inputs only works for elementwise funcs so returned shape might
2583+
fast_path = False # be incompatible with true output shape
25192584
if not fast_path:
25202585
# Not using the fast path, so we need to compute the chunks/blocks automatically
25212586
self._chunks, self._blocks = compute_chunks_blocks(self.shape, None, None, dtype=self.dtype)
@@ -3105,15 +3170,19 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
31053170
# Most in particular, castings like np.int8 et al. can be very useful to allow
31063171
# for desired data types in the output.
31073172
_operands = operands | local_vars
3108-
_globals = get_expr_globals(expression)
3109-
_globals |= dtype_symbols
31103173
# Check that operands are proper Operands, LazyArray or scalars; if not, convert to NDArray objects
31113174
for op, val in _operands.items():
31123175
if not (isinstance(val, (blosc2.Operand, blosc2.LazyArray, np.ndarray)) or np.isscalar(val)):
31133176
_operands[op] = blosc2.SimpleProxy(val)
3114-
new_expr = eval(_expression, _globals, _operands)
3177+
# for scalars just return value (internally converts to () if necessary)
3178+
opshapes = {k: v if not hasattr(v, "shape") else v.shape for k, v in _operands.items()}
3179+
_shape = infer_shape(_expression, opshapes) # infer shape, includes constructors
3180+
# substitutes with numpy operands (cheap for reductions) and
3181+
# defaults to blosc2 functions (cheap for constructors)
3182+
# have to handle slices since a[10] on a dummy variable of shape (1,1) doesn't work
3183+
desliced_expr, desliced_ops = extract_and_replace_slices(_expression, _operands)
3184+
new_expr = _numpy_eval_expr(desliced_expr, desliced_ops, prefer_blosc=True)
31153185
_dtype = new_expr.dtype
3116-
_shape = new_expr.shape
31173186
if isinstance(new_expr, blosc2.LazyExpr):
31183187
# DO NOT restore the original expression and operands
31193188
# Instead rebase operands and restore only constructors
@@ -3139,15 +3208,16 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
31393208
new_expr.operands = operands_
31403209
new_expr.operands_tosave = operands
31413210
elif isinstance(new_expr, blosc2.NDArray) and len(operands) == 1:
3142-
# passed either "a" or possible "a[:10]"
3211+
# passed "a", "a[:10]", 'sum(a)'
31433212
expression_, operands_ = conserve_functions(
31443213
_expression, _operands, {"o0": list(operands.values())[0]} | local_vars
31453214
)
31463215
new_expr = cls(None)
31473216
new_expr.expression = expression_
31483217
new_expr.operands = operands_
31493218
else:
3150-
# An immediate evaluation happened (e.g. all operands are numpy arrays)
3219+
# An immediate evaluation happened
3220+
# (e.g. all operands are numpy arrays or constructors)
31513221
new_expr = cls(None)
31523222
new_expr.expression = expression
31533223
new_expr.operands = operands
@@ -3348,6 +3418,42 @@ def save(self, **kwargs):
33483418
raise NotImplementedError("For safety reasons, this is not implemented for UDFs")
33493419

33503420

3421+
def _numpy_eval_expr(expression, operands, prefer_blosc=False):
3422+
ops = (
3423+
{
3424+
key: blosc2.ones((1,) * len(value.shape), dtype=value.dtype)
3425+
if hasattr(value, "chunks")
3426+
else value
3427+
for key, value in operands.items()
3428+
}
3429+
if prefer_blosc
3430+
else {
3431+
key: np.ones(np.ones(len(value.shape), dtype=int), dtype=value.dtype)
3432+
if hasattr(value, "shape")
3433+
else value
3434+
for key, value in operands.items()
3435+
}
3436+
)
3437+
if "contains" in expression:
3438+
_out = ne_evaluate(expression, local_dict=ops)
3439+
else:
3440+
# Create a globals dict with the functions of blosc2 preferentially
3441+
# (and numpy if can't find blosc2)
3442+
if prefer_blosc:
3443+
_globals = get_expr_globals(expression)
3444+
_globals |= dtype_symbols
3445+
else:
3446+
_globals = {f: getattr(np, f) for f in functions if f not in ("contains", "pow")}
3447+
try:
3448+
_out = eval(expression, _globals, ops)
3449+
except RuntimeWarning:
3450+
# Sometimes, numpy gets a RuntimeWarning when evaluating expressions
3451+
# with synthetic operands (1's). Let's try with numexpr, which is not so picky
3452+
# about this.
3453+
_out = ne_evaluate(expression, local_dict=ops)
3454+
return _out
3455+
3456+
33513457
def lazyudf(
33523458
func: Callable[[tuple, np.ndarray, tuple[int]], None],
33533459
inputs: Sequence[Any] | None,

tests/ndarray/test_reductions.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -491,15 +491,15 @@ def test_slice_lazy():
491491
def test_slicebrackets_lazy():
492492
shape = (20, 20)
493493
a = blosc2.linspace(0, 20, num=np.prod(shape), shape=shape)
494-
arr = blosc2.lazyexpr("anarr[10:15] + 1", {"anarr": a})
494+
arr = blosc2.lazyexpr("sum(anarr[10:15], axis=0) + anarr[10:15] + arange(20) + 1", {"anarr": a})
495495
newarr = arr.compute()
496-
np.testing.assert_allclose(newarr[:], a[10:15] + 1)
496+
np.testing.assert_allclose(newarr[:], np.sum(a[10:15], axis=0) + a[10:15] + np.arange(20) + 1)
497497

498498
# Try with getitem
499499
a = blosc2.linspace(0, 20, num=np.prod(shape), shape=shape)
500-
arr = blosc2.lazyexpr("anarr[10:15] + 1", {"anarr": a})
500+
arr = blosc2.lazyexpr("sum(anarr[10:15], axis=0) + anarr[10:15] + arange(20) + 1", {"anarr": a})
501501
newarr = arr[:3]
502-
res = a[10:15] + 1
502+
res = np.sum(a[10:15], axis=0) + a[10:15] + np.arange(20) + 1
503503
np.testing.assert_allclose(newarr, res[:3])
504504

505505
# Test other cases
@@ -511,6 +511,10 @@ def test_slicebrackets_lazy():
511511
newarr = arr.compute()
512512
np.testing.assert_allclose(newarr[:], a[10:15][2:9] + 1)
513513

514+
arr = blosc2.lazyexpr("sum(anarr[10:15], axis=1) + 1", {"anarr": a})
515+
newarr = arr.compute()
516+
np.testing.assert_allclose(newarr[:], np.sum(a[10:15], axis=1) + 1)
517+
514518
arr = blosc2.lazyexpr("anarr[10] + 1", {"anarr": a})
515519
newarr = arr.compute()
516520
np.testing.assert_allclose(newarr[:], a[10] + 1)

0 commit comments

Comments
 (0)