182 changes: 145 additions & 37 deletions src/blosc2/lazyexpr.py
@@ -51,6 +51,10 @@
ufunc_map_1param,
)

from .shape_utils import constructors, infer_shape, lin_alg_funcs, reducers

lin_alg_funcs += ("clip", "logaddexp")

if not blosc2.IS_WASM:
import numexpr

@@ -127,12 +131,6 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
"V": np.bytes_,
}

# All the available constructors and reducers necessary for the (string) expression evaluator
constructors = ("arange", "linspace", "fromiter", "zeros", "ones", "empty", "full", "frombuffer")
# Note that, as reshape is accepted as a method too, it should always come last in the list
constructors += ("reshape",)
reducers = ("sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "slice")

functions = [
"sin",
"cos",
@@ -658,6 +656,85 @@ def validate_expr(expr: str) -> None:
raise ValueError(f"Invalid method name: {method}")


def extract_and_replace_slices(expr, operands):
"""
Replaces all var.slice(...).slice(...) chains in expr with oN temporary variables.
Infers shapes using infer_shape and creates placeholder arrays in new_ops.

Returns:
new_expr: expression string with oN replacements
new_ops: dictionary mapping variable names (original and oN) to arrays
"""
# Copy shapes and operands
shapes = {k: () if not hasattr(v, "shape") else v.shape for k, v in operands.items()}
new_ops = operands.copy() # copy dictionary

# Parse the expression
tree = ast.parse(expr, mode="eval")

# Mapping of AST nodes to new variable names
replacements = {}

class SliceCollector(ast.NodeTransformer):
def visit_Call(self, node):
# Recursively visit children first
self.generic_visit(node)

# Detect method calls: obj.slice(...)
if isinstance(node.func, ast.Attribute) and node.func.attr == "slice":
obj = node.func.value

# If the object is already replaced, keep the replacement
base_name = None
if isinstance(obj, ast.Name):
base_name = obj.id
elif isinstance(obj, ast.Call) and obj in replacements:
base_name = replacements[obj]["base_var"]

# Build the full slice chain expression as a string
full_expr = ast.unparse(node)

# Create a new temporary variable
new_var = f"o{len(new_ops)}"

# Infer shape
try:
shape = infer_shape(full_expr, shapes)
except Exception as e:
print(f"⚠️ Shape inference failed for {full_expr}: {e}")
shape = ()

# Determine dtype
dtype = new_ops[base_name].dtype if base_name else None

# Create placeholder array
if base_name is not None and isinstance(new_ops[base_name], blosc2.NDArray):
new_op = blosc2.ones((1,) * len(shape), dtype=dtype)
else:
new_op = np.ones((1,) * len(shape), dtype=dtype)

new_ops[new_var] = new_op
shapes[new_var] = shape

# Record replacement
replacements[node] = {"new_var": new_var, "base_var": base_name}

# Replace the AST node with the new variable
return ast.Name(id=new_var, ctx=ast.Load())

return node

# Transform the AST
transformer = SliceCollector()
new_tree = transformer.visit(tree)
ast.fix_missing_locations(new_tree)

# Convert back to expression string
new_expr = ast.unparse(new_tree)

return new_expr, new_ops
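As a side note, here is a minimal sketch of how this rewriting is expected to behave; the operand values, the inferred shape and the generated `o1` name are illustrative assumptions (and it assumes `extract_and_replace_slices` is importable from `blosc2.lazyexpr`), not something taken from the diff:

```python
import numpy as np

# Hypothetical illustration of extract_and_replace_slices (names/shapes assumed).
a = np.arange(100, dtype=np.float64).reshape(10, 10)
operands = {"a": a}

# The .slice(...) chain cannot be evaluated on a (1, 1) dummy operand, so it is
# swapped for a temporary variable whose shape is inferred up front.
new_expr, new_ops = extract_and_replace_slices("a.slice(slice(0, 5)).sum()", operands)

# Expected (illustrative) outcome:
#   new_expr -> "o1.sum()"
#   new_ops  -> {"a": a, "o1": <ones placeholder with a.dtype and the ndim of the slice>}
```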


def get_expr_operands(expression: str) -> set:
"""
Given an expression in string form, return its operands.
@@ -2174,7 +2251,7 @@ def fuse_expressions(expr, new_base, dup_op):
return new_expr


def infer_dtype(op, value1, value2):
def check_dtype(op, value1, value2):
if op == "contains":
return np.dtype(np.bool_)

@@ -2262,7 +2339,7 @@ def __init__(self, new_op):  # noqa: C901
self.operands = {}
return
value1, op, value2 = new_op
dtype_ = infer_dtype(op, value1, value2) # perform some checks
dtype_ = check_dtype(op, value1, value2) # perform some checks
if value2 is None:
if isinstance(value1, LazyExpr):
self.expression = f"{op}({value1.expression})"
@@ -2443,25 +2520,7 @@ def dtype(self):
if any(v is None for v in self.operands.values()):
return None

operands = {
key: np.ones(np.ones(len(value.shape), dtype=int), dtype=value.dtype)
if hasattr(value, "shape")
else value
for key, value in self.operands.items()
}

if "contains" in self.expression:
_out = ne_evaluate(self.expression, local_dict=operands)
else:
# Create a globals dict with the functions of numpy
globals_dict = {f: getattr(np, f) for f in functions if f not in ("contains", "pow")}
try:
_out = eval(self.expression, globals_dict, operands)
except RuntimeWarning:
# Sometimes, numpy gets a RuntimeWarning when evaluating expressions
# with synthetic operands (1's). Let's try with numexpr, which is not so picky
# about this.
_out = ne_evaluate(self.expression, local_dict=operands)
_out = _numpy_eval_expr(self.expression, self.operands, prefer_blosc=False)
self._dtype_ = _out.dtype
self._expression_ = self.expression
return self._dtype_
@@ -2501,9 +2560,13 @@ def shape(self):
def chunks(self):
if hasattr(self, "_chunks"):
return self._chunks
self._shape, self._chunks, self._blocks, fast_path = validate_inputs(
shape, self._chunks, self._blocks, fast_path = validate_inputs(
self.operands, getattr(self, "_out", None)
)
if not hasattr(self, "_shape"):
self._shape = shape
if self._shape != shape:
    # validate_inputs only works for elementwise funcs, so the returned
    # shape might be incompatible with the true output shape
    fast_path = False
if not fast_path:
# Not using the fast path, so we need to compute the chunks/blocks automatically
self._chunks, self._blocks = compute_chunks_blocks(self.shape, None, None, dtype=self.dtype)
@@ -2513,9 +2576,13 @@ def blocks(self):
def blocks(self):
if hasattr(self, "_blocks"):
return self._blocks
self._shape, self._chunks, self._blocks, fast_path = validate_inputs(
shape, self._chunks, self._blocks, fast_path = validate_inputs(
self.operands, getattr(self, "_out", None)
)
if not hasattr(self, "_shape"):
self._shape = shape
if self._shape != shape:
    # validate_inputs only works for elementwise funcs, so the returned
    # shape might be incompatible with the true output shape
    fast_path = False
if not fast_path:
# Not using the fast path, so we need to compute the chunks/blocks automatically
self._chunks, self._blocks = compute_chunks_blocks(self.shape, None, None, dtype=self.dtype)
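For context, a brief sketch (assumed usage of `blosc2.lazyexpr`, not part of the diff) of the shape mismatch that the two new checks guard against: `validate_inputs` derives a shape from the operands, which only matches the true output for elementwise expressions:

```python
import blosc2

# Hypothetical sketch: elementwise vs. reduction output shapes (values assumed).
a = blosc2.ones((1000, 1000))

elementwise = blosc2.lazyexpr("a + 1", operands={"a": a})
reduction = blosc2.lazyexpr("sum(a, axis=0)", operands={"a": a})

# Elementwise: output shape equals the operand shape, so the operand
# chunks/blocks can be reused (fast path).
print(elementwise.shape)  # (1000, 1000)

# Reduction: the true output shape differs from the operand-derived shape,
# so fast_path is dropped and chunks/blocks are recomputed.
print(reduction.shape)  # (1000,)
```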
@@ -2848,13 +2915,13 @@ def find_args(expr):
return value, expression[idx:idx2]

def _compute_expr(self, item, kwargs):
if any(method in self.expression for method in reducers):
if any(method in self.expression for method in reducers + lin_alg_funcs):
# We have reductions or linear algebra functions in the expression
# (probably coming from a string lazyexpr); this also includes slice
_globals = get_expr_globals(self.expression)
lazy_expr = eval(self.expression, _globals, self.operands)
if not isinstance(lazy_expr, blosc2.LazyExpr):
key, mask = process_key(item, self.shape)
key, mask = process_key(item, lazy_expr.shape)
# An immediate evaluation happened (e.g. all operands are numpy arrays)
if hasattr(self, "_where_args"):
# We need to apply the where() operation
@@ -3105,15 +3172,19 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
# Most in particular, castings like np.int8 et al. can be very useful to allow
# for desired data types in the output.
_operands = operands | local_vars
_globals = get_expr_globals(expression)
_globals |= dtype_symbols
# Check that operands are proper Operands, LazyArray or scalars; if not, convert to NDArray objects
for op, val in _operands.items():
if not (isinstance(val, (blosc2.Operand, blosc2.LazyArray, np.ndarray)) or np.isscalar(val)):
_operands[op] = blosc2.SimpleProxy(val)
new_expr = eval(_expression, _globals, _operands)
# For scalars, keep the value itself (it is internally converted to shape () if necessary)
opshapes = {k: v if not hasattr(v, "shape") else v.shape for k, v in _operands.items()}
_shape = infer_shape(_expression, opshapes) # infer shape, includes constructors
# Substitute with dummy operands (cheap for reductions) and default to blosc2
# functions (cheap for constructors). Slices have to be handled first, since
# e.g. a[10] on a dummy operand of shape (1, 1) doesn't work.
desliced_expr, desliced_ops = extract_and_replace_slices(_expression, _operands)
new_expr = _numpy_eval_expr(desliced_expr, desliced_ops, prefer_blosc=True)
_dtype = new_expr.dtype
_shape = new_expr.shape
if isinstance(new_expr, blosc2.LazyExpr):
# DO NOT restore the original expression and operands
# Instead rebase operands and restore only constructors
@@ -3139,15 +3210,16 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
new_expr.operands = operands_
new_expr.operands_tosave = operands
elif isinstance(new_expr, blosc2.NDArray) and len(operands) == 1:
# passed either "a" or possible "a[:10]"
# passed "a", "a[:10]", 'sum(a)'
expression_, operands_ = conserve_functions(
_expression, _operands, {"o0": list(operands.values())[0]} | local_vars
)
new_expr = cls(None)
new_expr.expression = expression_
new_expr.operands = operands_
else:
# An immediate evaluation happened (e.g. all operands are numpy arrays)
# An immediate evaluation happened
# (e.g. all operands are numpy arrays or constructors)
new_expr = cls(None)
new_expr.expression = expression
new_expr.operands = operands
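The comment about `a[10]` above is worth a tiny, hedged demonstration of why plain dummy operands are not enough once slices enter the picture (pure NumPy, values assumed):

```python
import numpy as np

# Dummy (1, 1) stand-ins preserve ndim and dtype, which is enough for
# elementwise operations and reductions:
dummy = np.ones((1, 1), dtype=np.float64)
print((dummy + 1).dtype)     # float64
print(np.sum(dummy).dtype)   # float64

# ...but indexing/slicing them fails or returns the wrong shape:
# dummy[10]  # IndexError: index 10 is out of bounds for axis 0 with size 1

# Hence _new_expr first routes the expression through extract_and_replace_slices(),
# which swaps .slice(...) chains for placeholders with a pre-inferred shape.
```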
@@ -3348,6 +3420,42 @@ def save(self, **kwargs):
raise NotImplementedError("For safety reasons, this is not implemented for UDFs")


def _numpy_eval_expr(expression, operands, prefer_blosc=False):
ops = (
{
key: blosc2.ones((1,) * len(value.shape), dtype=value.dtype)
if hasattr(value, "chunks")
else value
for key, value in operands.items()
}
if prefer_blosc
else {
key: np.ones(np.ones(len(value.shape), dtype=int), dtype=value.dtype)
if hasattr(value, "shape")
else value
for key, value in operands.items()
}
)
if "contains" in expression:
_out = ne_evaluate(expression, local_dict=ops)
else:
# Create a globals dict with the functions of blosc2 preferentially
# (and numpy if can't find blosc2)
if prefer_blosc:
_globals = get_expr_globals(expression)
_globals |= dtype_symbols
else:
_globals = {f: getattr(np, f) for f in functions if f not in ("contains", "pow")}
try:
_out = eval(expression, _globals, ops)
except RuntimeWarning:
# Sometimes, numpy gets a RuntimeWarning when evaluating expressions
# with synthetic operands (1's). Let's try with numexpr, which is not so picky
# about this.
_out = ne_evaluate(expression, local_dict=ops)
return _out
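A hedged sketch of the dummy-operand trick this helper relies on; the operands and expression are assumptions for illustration only:

```python
import numpy as np

# Tiny all-ones stand-ins keep each operand's ndim and dtype, so the usual
# promotion rules give the output dtype without touching any real data.
ops = {
    "a": np.ones((1, 1), dtype=np.int32),
    "b": np.ones((1, 1), dtype=np.float32),
}
out = eval("a * 2 + b", {}, ops)
print(out.dtype)  # float64 under standard NumPy promotion rules
```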


def lazyudf(
func: Callable[[tuple, np.ndarray, tuple[int]], None],
inputs: Sequence[Any] | None,