From 73aa182ee8801a84733943219b4dba325df70d17 Mon Sep 17 00:00:00 2001 From: lshaw8317 Date: Sat, 4 Oct 2025 18:42:56 +0200 Subject: [PATCH 01/10] Fixed double evaluations for reductions --- src/blosc2/lazyexpr.py | 176 +++++++++++++++++++++++++------ tests/ndarray/test_reductions.py | 12 ++- 2 files changed, 149 insertions(+), 39 deletions(-) diff --git a/src/blosc2/lazyexpr.py b/src/blosc2/lazyexpr.py index 5f33c99e..fabbaeb4 100644 --- a/src/blosc2/lazyexpr.py +++ b/src/blosc2/lazyexpr.py @@ -51,6 +51,8 @@ ufunc_map_1param, ) +from .shape_utils import constructors, infer_shape, reducers + if not blosc2.IS_WASM: import numexpr @@ -127,12 +129,6 @@ def ne_evaluate(expression, local_dict=None, **kwargs): "V": np.bytes_, } -# All the available constructors and reducers necessary for the (string) expression evaluator -constructors = ("arange", "linspace", "fromiter", "zeros", "ones", "empty", "full", "frombuffer") -# Note that, as reshape is accepted as a method too, it should always come last in the list -constructors += ("reshape",) -reducers = ("sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "slice") - functions = [ "sin", "cos", @@ -658,6 +654,85 @@ def validate_expr(expr: str) -> None: raise ValueError(f"Invalid method name: {method}") +def extract_and_replace_slices(expr, operands): + """ + Replaces all var.slice(...).slice(...) chains in expr with oN temporary variables. + Infers shapes using infer_shape and creates placeholder arrays in new_ops. 
+ + Returns: + new_expr: expression string with oN replacements + new_ops: dictionary mapping variable names (original and oN) to arrays + """ + # Copy shapes and operands + shapes = {k: () if not hasattr(v, "shape") else v.shape for k, v in operands.items()} + new_ops = operands.copy() # copy dictionary + + # Parse the expression + tree = ast.parse(expr, mode="eval") + + # Mapping of AST nodes to new variable names + replacements = {} + + class SliceCollector(ast.NodeTransformer): + def visit_Call(self, node): + # Recursively visit children first + self.generic_visit(node) + + # Detect method calls: obj.slice(...) + if isinstance(node.func, ast.Attribute) and node.func.attr == "slice": + obj = node.func.value + + # If the object is already replaced, keep the replacement + base_name = None + if isinstance(obj, ast.Name): + base_name = obj.id + elif isinstance(obj, ast.Call) and obj in replacements: + base_name = replacements[obj]["base_var"] + + # Build the full slice chain expression as a string + full_expr = ast.unparse(node) + + # Create a new temporary variable + new_var = f"o{len(new_ops)}" + + # Infer shape + try: + shape = infer_shape(full_expr, shapes) + except Exception as e: + print(f"⚠️ Shape inference failed for {full_expr}: {e}") + shape = () + + # Determine dtype + dtype = new_ops[base_name].dtype if base_name else None + + # Create placeholder array + if isinstance(new_ops[base_name], blosc2.NDArray): + new_op = blosc2.ones((1,) * len(shape), dtype=dtype) + else: + new_op = np.ones((1,) * len(shape), dtype=dtype) + + new_ops[new_var] = new_op + shapes[new_var] = shape + + # Record replacement + replacements[node] = {"new_var": new_var, "base_var": base_name} + + # Replace the AST node with the new variable + return ast.Name(id=new_var, ctx=ast.Load()) + + return node + + # Transform the AST + transformer = SliceCollector() + new_tree = transformer.visit(tree) + ast.fix_missing_locations(new_tree) + + # Convert back to expression string + new_expr = 
ast.unparse(new_tree) + + return new_expr, new_ops + + def get_expr_operands(expression: str) -> set: """ Given an expression in string form, return its operands. @@ -2174,7 +2249,7 @@ def fuse_expressions(expr, new_base, dup_op): return new_expr -def infer_dtype(op, value1, value2): +def check_dtype(op, value1, value2): if op == "contains": return np.dtype(np.bool_) @@ -2262,7 +2337,7 @@ def __init__(self, new_op): # noqa: C901 self.operands = {} return value1, op, value2 = new_op - dtype_ = infer_dtype(op, value1, value2) # perform some checks + dtype_ = check_dtype(op, value1, value2) # perform some checks if value2 is None: if isinstance(value1, LazyExpr): self.expression = f"{op}({value1.expression})" @@ -2443,25 +2518,7 @@ def dtype(self): if any(v is None for v in self.operands.values()): return None - operands = { - key: np.ones(np.ones(len(value.shape), dtype=int), dtype=value.dtype) - if hasattr(value, "shape") - else value - for key, value in self.operands.items() - } - - if "contains" in self.expression: - _out = ne_evaluate(self.expression, local_dict=operands) - else: - # Create a globals dict with the functions of numpy - globals_dict = {f: getattr(np, f) for f in functions if f not in ("contains", "pow")} - try: - _out = eval(self.expression, globals_dict, operands) - except RuntimeWarning: - # Sometimes, numpy gets a RuntimeWarning when evaluating expressions - # with synthetic operands (1's). Let's try with numexpr, which is not so picky - # about this. 
- _out = ne_evaluate(self.expression, local_dict=operands) + _out = _numpy_eval_expr(self.expression, self.operands, prefer_blosc=False) self._dtype_ = _out.dtype self._expression_ = self.expression return self._dtype_ @@ -2501,9 +2558,13 @@ def shape(self): def chunks(self): if hasattr(self, "_chunks"): return self._chunks - self._shape, self._chunks, self._blocks, fast_path = validate_inputs( + shape, self._chunks, self._blocks, fast_path = validate_inputs( self.operands, getattr(self, "_out", None) ) + if not hasattr(self, "_shape"): + self._shape = shape + if self._shape != shape: # validate inputs only works for elementwise funcs so returned shape might + fast_path = False # be incompatible with true output shape if not fast_path: # Not using the fast path, so we need to compute the chunks/blocks automatically self._chunks, self._blocks = compute_chunks_blocks(self.shape, None, None, dtype=self.dtype) @@ -2513,9 +2574,13 @@ def chunks(self): def blocks(self): if hasattr(self, "_blocks"): return self._blocks - self._shape, self._chunks, self._blocks, fast_path = validate_inputs( + shape, self._chunks, self._blocks, fast_path = validate_inputs( self.operands, getattr(self, "_out", None) ) + if not hasattr(self, "_shape"): + self._shape = shape + if self._shape != shape: # validate inputs only works for elementwise funcs so returned shape might + fast_path = False # be incompatible with true output shape if not fast_path: # Not using the fast path, so we need to compute the chunks/blocks automatically self._chunks, self._blocks = compute_chunks_blocks(self.shape, None, None, dtype=self.dtype) @@ -3105,15 +3170,19 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No # Most in particular, castings like np.int8 et al. can be very useful to allow # for desired data types in the output. 
_operands = operands | local_vars - _globals = get_expr_globals(expression) - _globals |= dtype_symbols # Check that operands are proper Operands, LazyArray or scalars; if not, convert to NDArray objects for op, val in _operands.items(): if not (isinstance(val, (blosc2.Operand, blosc2.LazyArray, np.ndarray)) or np.isscalar(val)): _operands[op] = blosc2.SimpleProxy(val) - new_expr = eval(_expression, _globals, _operands) + # for scalars just return value (internally converts to () if necessary) + opshapes = {k: v if not hasattr(v, "shape") else v.shape for k, v in _operands.items()} + _shape = infer_shape(_expression, opshapes) # infer shape, includes constructors + # substitutes with numpy operands (cheap for reductions) and + # defaults to blosc2 functions (cheap for constructors) + # have to handle slices since a[10] on a dummy variable of shape (1,1) doesn't work + desliced_expr, desliced_ops = extract_and_replace_slices(_expression, _operands) + new_expr = _numpy_eval_expr(desliced_expr, desliced_ops, prefer_blosc=True) _dtype = new_expr.dtype - _shape = new_expr.shape if isinstance(new_expr, blosc2.LazyExpr): # DO NOT restore the original expression and operands # Instead rebase operands and restore only constructors @@ -3139,7 +3208,7 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No new_expr.operands = operands_ new_expr.operands_tosave = operands elif isinstance(new_expr, blosc2.NDArray) and len(operands) == 1: - # passed either "a" or possible "a[:10]" + # passed "a", "a[:10]", 'sum(a)' expression_, operands_ = conserve_functions( _expression, _operands, {"o0": list(operands.values())[0]} | local_vars ) @@ -3147,7 +3216,8 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No new_expr.expression = expression_ new_expr.operands = operands_ else: - # An immediate evaluation happened (e.g. all operands are numpy arrays) + # An immediate evaluation happened + # (e.g. 
all operands are numpy arrays or constructors) new_expr = cls(None) new_expr.expression = expression new_expr.operands = operands @@ -3348,6 +3418,42 @@ def save(self, **kwargs): raise NotImplementedError("For safety reasons, this is not implemented for UDFs") +def _numpy_eval_expr(expression, operands, prefer_blosc=False): + ops = ( + { + key: blosc2.ones((1,) * len(value.shape), dtype=value.dtype) + if hasattr(value, "chunks") + else value + for key, value in operands.items() + } + if prefer_blosc + else { + key: np.ones(np.ones(len(value.shape), dtype=int), dtype=value.dtype) + if hasattr(value, "shape") + else value + for key, value in operands.items() + } + ) + if "contains" in expression: + _out = ne_evaluate(expression, local_dict=ops) + else: + # Create a globals dict with the functions of blosc2 preferentially + # (and numpy if can't find blosc2) + if prefer_blosc: + _globals = get_expr_globals(expression) + _globals |= dtype_symbols + else: + _globals = {f: getattr(np, f) for f in functions if f not in ("contains", "pow")} + try: + _out = eval(expression, _globals, ops) + except RuntimeWarning: + # Sometimes, numpy gets a RuntimeWarning when evaluating expressions + # with synthetic operands (1's). Let's try with numexpr, which is not so picky + # about this. 
+ _out = ne_evaluate(expression, local_dict=ops) + return _out + + def lazyudf( func: Callable[[tuple, np.ndarray, tuple[int]], None], inputs: Sequence[Any] | None, diff --git a/tests/ndarray/test_reductions.py b/tests/ndarray/test_reductions.py index d362bfbf..dbaa7997 100644 --- a/tests/ndarray/test_reductions.py +++ b/tests/ndarray/test_reductions.py @@ -491,15 +491,15 @@ def test_slice_lazy(): def test_slicebrackets_lazy(): shape = (20, 20) a = blosc2.linspace(0, 20, num=np.prod(shape), shape=shape) - arr = blosc2.lazyexpr("anarr[10:15] + 1", {"anarr": a}) + arr = blosc2.lazyexpr("sum(anarr[10:15], axis=0) + anarr[10:15] + arange(20) + 1", {"anarr": a}) newarr = arr.compute() - np.testing.assert_allclose(newarr[:], a[10:15] + 1) + np.testing.assert_allclose(newarr[:], np.sum(a[10:15], axis=0) + a[10:15] + np.arange(20) + 1) # Try with getitem a = blosc2.linspace(0, 20, num=np.prod(shape), shape=shape) - arr = blosc2.lazyexpr("anarr[10:15] + 1", {"anarr": a}) + arr = blosc2.lazyexpr("sum(anarr[10:15], axis=0) + anarr[10:15] + arange(20) + 1", {"anarr": a}) newarr = arr[:3] - res = a[10:15] + 1 + res = np.sum(a[10:15], axis=0) + a[10:15] + np.arange(20) + 1 np.testing.assert_allclose(newarr, res[:3]) # Test other cases @@ -511,6 +511,10 @@ def test_slicebrackets_lazy(): newarr = arr.compute() np.testing.assert_allclose(newarr[:], a[10:15][2:9] + 1) + arr = blosc2.lazyexpr("sum(anarr[10:15], axis=1) + 1", {"anarr": a}) + newarr = arr.compute() + np.testing.assert_allclose(newarr[:], np.sum(a[10:15], axis=1) + 1) + arr = blosc2.lazyexpr("anarr[10] + 1", {"anarr": a}) newarr = arr.compute() np.testing.assert_allclose(newarr[:], a[10] + 1) From 7424ce81e039ad8d0daba19f73f50e667e567181 Mon Sep 17 00:00:00 2001 From: lshaw8317 Date: Sat, 4 Oct 2025 19:01:39 +0200 Subject: [PATCH 02/10] Add shape_utils.py file --- src/blosc2/shape_utils.py | 267 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 267 insertions(+) create mode 100644 src/blosc2/shape_utils.py diff 
--git a/src/blosc2/shape_utils.py b/src/blosc2/shape_utils.py new file mode 100644 index 00000000..4de77eb1 --- /dev/null +++ b/src/blosc2/shape_utils.py @@ -0,0 +1,267 @@ +import ast + +from numpy import broadcast_shapes + +reducers = ("sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "slice") + +# All the available constructors and reducers necessary for the (string) expression evaluator +constructors = ( + "arange", + "linspace", + "fromiter", + "zeros", + "ones", + "empty", + "full", + "frombuffer", + "full_like", + "zeros_like", + "ones_like", + "empty_like", +) +# Note that, as reshape is accepted as a method too, it should always come last in the list +constructors += ("reshape",) + + +# --- Shape utilities --- +def reduce_shape(shape, axis, keepdims): + """Reduce shape along given axis or axes (collapse dimensions).""" + if shape is None: + return None # unknown shape + + # full reduction + if axis is None: + return (1,) * len(shape) if keepdims else () + + # normalize to tuple + if isinstance(axis, int): + axes = (axis,) + else: + axes = tuple(axis) + + # normalize negative axes + axes = tuple(a + len(shape) if a < 0 else a for a in axes) + + if keepdims: + return tuple(d if i not in axes else 1 for i, d in enumerate(shape)) + else: + return tuple(d for i, d in enumerate(shape) if i not in axes) + + +def slice_shape(shape, slices): + """Infer shape after slicing.""" + result = [] + for dim, sl in zip(shape, slices, strict=False): + if isinstance(sl, int): # indexing removes the axis + continue + if isinstance(sl, slice): + start = sl.start or 0 + stop = sl.stop if sl.stop is not None else dim + step = sl.step or 1 + length = max(0, (stop - start + (step - 1)) // step) + result.append(length) + else: + raise ValueError(f"Unsupported slice type: {sl}") + result.extend(shape[len(slices) :]) # untouched trailing dims + return tuple(result) + + +def elementwise(*args): + """All args must broadcast elementwise.""" + shape = args[0] + shape = shape 
if shape is not None else () + for s in args[1:]: + shape = broadcast_shapes(shape, s) if s is not None else shape + return shape + + +# --- Function registry --- +FUNCTIONS = { # ignore out arg + func: lambda x, axis=None, keepdims=False, out=None: reduce_shape(x, axis, keepdims) + for func in reducers + # any unknown function will default to elementwise +} + + +# --- AST Shape Inferencer --- +class ShapeInferencer(ast.NodeVisitor): + def __init__(self, shapes): + self.shapes = shapes + + def visit_Name(self, node): + if node.id not in self.shapes: + raise ValueError(f"Unknown symbol: {node.id}") + s = self.shapes[node.id] + if isinstance(s, tuple): + return s + else: # passed a scalar value + return () + + def visit_Call(self, node): # noqa : C901 + func_name = getattr(node.func, "id", None) + attr_name = getattr(node.func, "attr", None) + + # --- Recursive method-chain support --- + obj_shape = None + if isinstance(node.func, ast.Attribute): + obj_shape = self.visit(node.func.value) + + # --- Parse keyword args --- + kwargs = {} + for kw in node.keywords: + if isinstance(kw.value, ast.Constant): + kwargs[kw.arg] = kw.value.value + elif isinstance(kw.value, ast.Tuple): + kwargs[kw.arg] = tuple( + e.value if isinstance(e, ast.Constant) else self._lookup_value(e) for e in kw.value.elts + ) + else: + kwargs[kw.arg] = self._lookup_value(kw.value) + + # ------- handle constructors --------------- + if func_name in constructors or attr_name == "reshape": + # shape kwarg directly provided + if "shape" in kwargs: + val = kwargs["shape"] + return val if isinstance(val, tuple) else (val,) + + # ---- array constructors like zeros, ones, full, etc. 
---- + elif func_name in ( + "zeros", + "ones", + "empty", + "full", + "full_like", + "zeros_like", + "empty_like", + "ones_like", + ): + if node.args: + shape_arg = node.args[0] + if isinstance(shape_arg, ast.Tuple): + shape = tuple(self._const_or_lookup(e) for e in shape_arg.elts) + elif isinstance(shape_arg, ast.Constant): + shape = (shape_arg.value,) + else: + shape = self._lookup_value(shape_arg) + shape = shape if isinstance(shape, tuple) else (shape,) + return shape + + # ---- arange ---- + elif func_name == "arange": + start = self._const_or_lookup(node.args[0]) if node.args else 0 + stop = self._const_or_lookup(node.args[1]) if len(node.args) > 1 else None + step = self._const_or_lookup(node.args[2]) if len(node.args) > 2 else 1 + shape = self._const_or_lookup(node.args[4]) if len(node.args) > 4 else kwargs.get("shape") + + if shape is not None: + return shape if isinstance(shape, tuple) else (shape,) + + # Fallback to numeric difference if possible + if stop is None: + stop, start = start, 0 + try: + NUM = int((stop - start) / step) + except Exception: + # symbolic or non-numeric: unknown 1D + return ((),) + return (max(NUM, 0),) + + # ---- linspace ---- + elif func_name == "linspace": + num = self._const_or_lookup(node.args[2]) if len(node.args) > 2 else kwargs.get("num") + shape = self._const_or_lookup(node.args[5]) if len(node.args) > 5 else kwargs.get("shape") + if shape is not None: + return shape if isinstance(shape, tuple) else (shape,) + if num is not None: + return (num,) + raise ValueError("linspace requires either shape or num argument") + + elif func_name == "frombuffer" or func_name == "fromiter": + count = kwargs.get("count") + return (count,) if count else () + + elif func_name == "reshape" or attr_name == "reshape": + if node.args: + shape_arg = node.args[-1] + if isinstance(shape_arg, ast.Tuple): + return tuple(self._const_or_lookup(e) for e in shape_arg.elts) + return () + + else: + raise ValueError(f"Unrecognized constructor or missing 
shape argument for {func_name}") + + # --- Special-case .slice((slice(...), ...)) --- + if attr_name == "slice": + if not node.args: + raise ValueError(".slice() requires an argument") + slice_arg = node.args[0] + if isinstance(slice_arg, ast.Tuple): + slices = [self._eval_slice(s) for s in slice_arg.elts] + else: + slices = [self._eval_slice(slice_arg)] + return slice_shape(obj_shape, slices) + + # --- Evaluate argument shapes normally --- + args = [self.visit(arg) for arg in node.args] + + if func_name in FUNCTIONS: + return FUNCTIONS[func_name](*args, **kwargs) + if attr_name in FUNCTIONS: + return FUNCTIONS[attr_name](obj_shape, **kwargs) + + shapes = [obj_shape] + args if obj_shape is not None else args + shapes = [s for s in shapes if s is not None] + return elementwise(*shapes) if shapes else () + + def visit_Compare(self, node): + shapes = [self.visit(node.left)] + [self.visit(c) for c in node.comparators] + return elementwise(*shapes) + + def visit_BinOp(self, node): + left = self.visit(node.left) + right = self.visit(node.right) + left = () if left is None else left + right = () if right is None else right + return broadcast_shapes(left, right) + + def _eval_slice(self, node): + if isinstance(node, ast.Slice): + return slice( + node.lower.value if node.lower else None, + node.upper.value if node.upper else None, + node.step.value if node.step else None, + ) + elif isinstance(node, ast.Call) and getattr(node.func, "id", None) == "slice": + # handle explicit slice() constructor + args = [a.value if isinstance(a, ast.Constant) else None for a in node.args] + return slice(*args) + elif isinstance(node, ast.Constant): + return node.value + else: + raise ValueError(f"Unsupported slice expression: {ast.dump(node)}") + + def _lookup_value(self, node): + """Look up a value in self.shapes if node is a variable name, else constant value.""" + if isinstance(node, ast.Name): + return self.shapes.get(node.id, None) + elif isinstance(node, ast.Constant): + return 
node.value + else: + return None + + def _const_or_lookup(self, node): + """Return constant value or resolve name to scalar from shapes.""" + if isinstance(node, ast.Constant): + return node.value + elif isinstance(node, ast.Name): + return self.shapes.get(node.id, None) + else: + return None + + +# --- Public API --- +def infer_shape(expr, shapes): + tree = ast.parse(expr, mode="eval") + inferencer = ShapeInferencer(shapes) + return inferencer.visit(tree.body) From 71967b6e7c9241e99cd291db173508b13fb7095d Mon Sep 17 00:00:00 2001 From: lshaw8317 Date: Sun, 5 Oct 2025 14:52:07 +0200 Subject: [PATCH 03/10] Enable lazy imperative matmul etc. --- src/blosc2/lazyexpr.py | 8 +++-- src/blosc2/shape_utils.py | 66 ++++++++++++++++++++-------------- tests/ndarray/test_lazyexpr.py | 16 +++++++++ 3 files changed, 61 insertions(+), 29 deletions(-) diff --git a/src/blosc2/lazyexpr.py b/src/blosc2/lazyexpr.py index fabbaeb4..66faf38e 100644 --- a/src/blosc2/lazyexpr.py +++ b/src/blosc2/lazyexpr.py @@ -51,7 +51,9 @@ ufunc_map_1param, ) -from .shape_utils import constructors, infer_shape, reducers +from .shape_utils import constructors, infer_shape, lin_alg_funcs, reducers + +lin_alg_funcs += ("clip", "logaddexp") if not blosc2.IS_WASM: import numexpr @@ -2913,13 +2915,13 @@ def find_args(expr): return value, expression[idx:idx2] def _compute_expr(self, item, kwargs): - if any(method in self.expression for method in reducers): + if any(method in self.expression for method in reducers + lin_alg_funcs): # We have reductions in the expression (probably coming from a string lazyexpr) # Also includes slice _globals = get_expr_globals(self.expression) lazy_expr = eval(self.expression, _globals, self.operands) if not isinstance(lazy_expr, blosc2.LazyExpr): - key, mask = process_key(item, self.shape) + key, mask = process_key(item, lazy_expr.shape) # An immediate evaluation happened (e.g. 
all operands are numpy arrays) if hasattr(self, "_where_args"): # We need to apply the where() operation diff --git a/src/blosc2/shape_utils.py b/src/blosc2/shape_utils.py index 4de77eb1..ca488df7 100644 --- a/src/blosc2/shape_utils.py +++ b/src/blosc2/shape_utils.py @@ -2,7 +2,21 @@ from numpy import broadcast_shapes -reducers = ("sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "slice") +lin_alg_funcs = ( + "concat", + "diagonal", + "expand_dims", + "matmul", + "matrix_transpose", + "outer", + "permute_dims", + "squeeze", + "stack", + "tensordot", + "transpose", + "vecdot", +) +reducers = ("sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "slice", "count_nonzero") # All the available constructors and reducers necessary for the (string) expression evaluator constructors = ( @@ -18,6 +32,7 @@ "zeros_like", "ones_like", "empty_like", + "eye", ) # Note that, as reshape is accepted as a method too, it should always come last in the list constructors += ("reshape",) @@ -50,6 +65,8 @@ def reduce_shape(shape, axis, keepdims): def slice_shape(shape, slices): """Infer shape after slicing.""" + if shape is None: + return None result = [] for dim, sl in zip(shape, slices, strict=False): if isinstance(sl, int): # indexing removes the axis @@ -68,11 +85,9 @@ def slice_shape(shape, slices): def elementwise(*args): """All args must broadcast elementwise.""" - shape = args[0] - shape = shape if shape is not None else () - for s in args[1:]: - shape = broadcast_shapes(shape, s) if s is not None else shape - return shape + if None in args: + return None + return broadcast_shapes(*args) # --- Function registry --- @@ -118,6 +133,9 @@ def visit_Call(self, node): # noqa : C901 else: kwargs[kw.arg] = self._lookup_value(kw.value) + if func_name in lin_alg_funcs: + return None # need to implement shape handling for these funcs + # ------- handle constructors --------------- if func_name in constructors or attr_name == "reshape": # shape kwarg directly 
provided @@ -139,7 +157,7 @@ def visit_Call(self, node): # noqa : C901 if node.args: shape_arg = node.args[0] if isinstance(shape_arg, ast.Tuple): - shape = tuple(self._const_or_lookup(e) for e in shape_arg.elts) + shape = tuple(self._lookup_value(e) for e in shape_arg.elts) elif isinstance(shape_arg, ast.Constant): shape = (shape_arg.value,) else: @@ -149,10 +167,10 @@ def visit_Call(self, node): # noqa : C901 # ---- arange ---- elif func_name == "arange": - start = self._const_or_lookup(node.args[0]) if node.args else 0 - stop = self._const_or_lookup(node.args[1]) if len(node.args) > 1 else None - step = self._const_or_lookup(node.args[2]) if len(node.args) > 2 else 1 - shape = self._const_or_lookup(node.args[4]) if len(node.args) > 4 else kwargs.get("shape") + start = self._lookup_value(node.args[0]) if node.args else 0 + stop = self._lookup_value(node.args[1]) if len(node.args) > 1 else None + step = self._lookup_value(node.args[2]) if len(node.args) > 2 else 1 + shape = self._lookup_value(node.args[4]) if len(node.args) > 4 else kwargs.get("shape") if shape is not None: return shape if isinstance(shape, tuple) else (shape,) @@ -169,8 +187,8 @@ def visit_Call(self, node): # noqa : C901 # ---- linspace ---- elif func_name == "linspace": - num = self._const_or_lookup(node.args[2]) if len(node.args) > 2 else kwargs.get("num") - shape = self._const_or_lookup(node.args[5]) if len(node.args) > 5 else kwargs.get("shape") + num = self._lookup_value(node.args[2]) if len(node.args) > 2 else kwargs.get("num") + shape = self._lookup_value(node.args[5]) if len(node.args) > 5 else kwargs.get("shape") if shape is not None: return shape if isinstance(shape, tuple) else (shape,) if num is not None: @@ -180,12 +198,16 @@ def visit_Call(self, node): # noqa : C901 elif func_name == "frombuffer" or func_name == "fromiter": count = kwargs.get("count") return (count,) if count else () + elif func_name == "eye": + N = self._lookup_value(node.args[0]) + M = 
self._lookup_value(node.args[1]) if len(node.args) > 1 else kwargs.get("M") + return (N, N) if M is None else (N, M) elif func_name == "reshape" or attr_name == "reshape": if node.args: shape_arg = node.args[-1] if isinstance(shape_arg, ast.Tuple): - return tuple(self._const_or_lookup(e) for e in shape_arg.elts) + return tuple(self._lookup_value(e) for e in shape_arg.elts) return () else: @@ -218,12 +240,13 @@ def visit_Compare(self, node): shapes = [self.visit(node.left)] + [self.visit(c) for c in node.comparators] return elementwise(*shapes) + def visit_Constant(self, node): + return () + def visit_BinOp(self, node): left = self.visit(node.left) right = self.visit(node.right) - left = () if left is None else left - right = () if right is None else right - return broadcast_shapes(left, right) + return elementwise(left, right) def _eval_slice(self, node): if isinstance(node, ast.Slice): @@ -250,15 +273,6 @@ def _lookup_value(self, node): else: return None - def _const_or_lookup(self, node): - """Return constant value or resolve name to scalar from shapes.""" - if isinstance(node, ast.Constant): - return node.value - elif isinstance(node, ast.Name): - return self.shapes.get(node.id, None) - else: - return None - # --- Public API --- def infer_shape(expr, shapes): diff --git a/tests/ndarray/test_lazyexpr.py b/tests/ndarray/test_lazyexpr.py index e616129a..1e7a8dc6 100644 --- a/tests/ndarray/test_lazyexpr.py +++ b/tests/ndarray/test_lazyexpr.py @@ -1583,3 +1583,19 @@ def __len__(self): lb = blosc2.lazyexpr("b + c + 1") np.testing.assert_array_equal(lb[:], a + a + 1) + + +def test_not_numexpr(): + shape = (20, 20) + a = blosc2.linspace(0, 20, num=np.prod(shape), shape=shape) + b = blosc2.ones((20, 1)) + d_blosc2 = blosc2.evaluate("logaddexp(a, b) + a") + npa = a[()] + npb = b[()] + np.testing.assert_array_almost_equal(d_blosc2, np.logaddexp(npa, npb) + npa) + # TODO: Implement __add__ etc. 
for LazyUDF so this line works + # d_blosc2 = blosc2.evaluate(f"logaddexp(a, b) + clip(a, 6, 12)") + arr = blosc2.lazyexpr("matmul(a,b) + a ") + assert isinstance(arr, blosc2.LazyExpr) + assert arr.shape is None # can't calculate shape for linalg funcs yet + np.testing.assert_array_almost_equal(arr[()], np.matmul(npa, npb) + a) From 6e5d0cde84cfb190ccaad94f9b18260b3d6e3e66 Mon Sep 17 00:00:00 2001 From: lshaw8317 Date: Mon, 6 Oct 2025 09:44:26 +0200 Subject: [PATCH 04/10] Make sure to change numpy arrays too --- src/blosc2/lazyexpr.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/blosc2/lazyexpr.py b/src/blosc2/lazyexpr.py index 66faf38e..c10997df 100644 --- a/src/blosc2/lazyexpr.py +++ b/src/blosc2/lazyexpr.py @@ -3421,21 +3421,29 @@ def save(self, **kwargs): def _numpy_eval_expr(expression, operands, prefer_blosc=False): - ops = ( - { + if prefer_blosc: + # convert blosc arrays to small dummies + ops = { key: blosc2.ones((1,) * len(value.shape), dtype=value.dtype) if hasattr(value, "chunks") - else value + else value # some of these could be numpy arrays for key, value in operands.items() } - if prefer_blosc - else { + # change numpy arrays + ops = { + key: np.ones((1,) * len(value.shape), dtype=value.dtype) + if isinstance(value, np.ndarray) + else value + for key, value in ops.items() + } + else: + ops = { key: np.ones(np.ones(len(value.shape), dtype=int), dtype=value.dtype) if hasattr(value, "shape") else value for key, value in operands.items() } - ) + if "contains" in expression: _out = ne_evaluate(expression, local_dict=ops) else: From 6a259574e0f07979611e7e7245fb9962a98b82c0 Mon Sep 17 00:00:00 2001 From: lshaw8317 Date: Mon, 6 Oct 2025 14:16:45 +0200 Subject: [PATCH 05/10] Add shape parsing of linalg funcs --- src/blosc2/lazyexpr.py | 36 +++---- src/blosc2/linalg.py | 4 - src/blosc2/shape_utils.py | 166 ++++++++++++++++++++++++++++++++- tests/ndarray/test_lazyexpr.py | 99 +++++++++++++++++++- 4 files changed, 
273 insertions(+), 32 deletions(-) diff --git a/src/blosc2/lazyexpr.py b/src/blosc2/lazyexpr.py index c10997df..bc5adcac 100644 --- a/src/blosc2/lazyexpr.py +++ b/src/blosc2/lazyexpr.py @@ -620,6 +620,11 @@ def compute_smaller_slice(larger_shape, smaller_shape, larger_slice): valid_methods |= {"int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"} valid_methods |= {"float32", "float64", "complex64", "complex128"} valid_methods |= {"bool", "str", "bytes"} +valid_methods |= { + name + for name in dir(blosc2.NDArray) + if callable(getattr(blosc2.NDArray, name)) and not name.startswith("_") +} def validate_expr(expr: str) -> None: @@ -2002,7 +2007,7 @@ def reduce_slices( # noqa: C901 continue if where is None: - if expression == "o0": + if expression == "o0" or expression == "(o0)": # We don't have an actual expression, so avoid a copy except to make contiguous result = np.require(chunk_operands["o0"], requirements="C") else: @@ -3168,9 +3173,6 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No # in guessing mode to avoid computing reductions # Extract possible numpy scalars _expression, local_vars = extract_numpy_scalars(expression) - # Let's include numpy and blosc2 as operands so that some functions can be used - # Most in particular, castings like np.int8 et al. can be very useful to allow - # for desired data types in the output. 
_operands = operands | local_vars # Check that operands are proper Operands, LazyArray or scalars; if not, convert to NDArray objects for op, val in _operands.items(): @@ -3179,10 +3181,10 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No # for scalars just return value (internally converts to () if necessary) opshapes = {k: v if not hasattr(v, "shape") else v.shape for k, v in _operands.items()} _shape = infer_shape(_expression, opshapes) # infer shape, includes constructors - # substitutes with numpy operands (cheap for reductions) and - # defaults to blosc2 functions (cheap for constructors) # have to handle slices since a[10] on a dummy variable of shape (1,1) doesn't work desliced_expr, desliced_ops = extract_and_replace_slices(_expression, _operands) + # substitutes with dummy operands (cheap for reductions) and + # defaults to blosc2 functions (cheap for constructors) new_expr = _numpy_eval_expr(desliced_expr, desliced_ops, prefer_blosc=True) _dtype = new_expr.dtype if isinstance(new_expr, blosc2.LazyExpr): @@ -3205,24 +3207,16 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No if counter == 0 and char == ",": break expression_ = finalexpr[:-1] # remove trailing comma - new_expr.expression = f"({expression_})" # force parenthesis - new_expr.expression_tosave = expression - new_expr.operands = operands_ - new_expr.operands_tosave = operands - elif isinstance(new_expr, blosc2.NDArray) and len(operands) == 1: - # passed "a", "a[:10]", 'sum(a)' - expression_, operands_ = conserve_functions( - _expression, _operands, {"o0": list(operands.values())[0]} | local_vars - ) - new_expr = cls(None) - new_expr.expression = expression_ - new_expr.operands = operands_ else: + new_expr = cls(None) # An immediate evaluation happened # (e.g. 
all operands are numpy arrays or constructors) - new_expr = cls(None) - new_expr.expression = expression - new_expr.operands = operands + # or passed "a", "a[:10]", 'sum(a)' + expression_, operands_ = conserve_functions(_expression, _operands, local_vars) + new_expr.expression = f"({expression_})" # force parenthesis + new_expr.operands = operands_ + new_expr.expression_tosave = expression + new_expr.operands_tosave = operands # Cache the dtype and shape (should be immutable) new_expr._dtype = _dtype new_expr._shape = _shape diff --git a/src/blosc2/linalg.py b/src/blosc2/linalg.py index 1ea0617e..65cc2fb9 100644 --- a/src/blosc2/linalg.py +++ b/src/blosc2/linalg.py @@ -353,10 +353,6 @@ def vecdot(x1: blosc2.NDArray, x2: blosc2.NDArray, axis: int = -1, **kwargs) -> a_keep[a_axes] = False b_keep = [True] * x2.ndim b_keep[b_axes] = False - x1shape = np.array(x1.shape) - x2shape = np.array(x2.shape) - result_shape = np.broadcast_shapes(x1shape[a_keep], x2shape[b_keep]) - result = blosc2.zeros(result_shape, dtype=np.result_type(x1, x2), **kwargs) x1shape = np.array(x1.shape) x2shape = np.array(x2.shape) diff --git a/src/blosc2/shape_utils.py b/src/blosc2/shape_utils.py index ca488df7..f0756366 100644 --- a/src/blosc2/shape_utils.py +++ b/src/blosc2/shape_utils.py @@ -1,4 +1,5 @@ import ast +import builtins from numpy import broadcast_shapes @@ -15,6 +16,8 @@ "tensordot", "transpose", "vecdot", + "T", + "mT", ) reducers = ("sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "slice", "count_nonzero") @@ -39,6 +42,152 @@ # --- Shape utilities --- +def linalg_shape(func_name, args, kwargs): # noqa: C901 + # --- Linear algebra and tensor manipulation --- + a = args[0] if args else None + if a is None or any(s is None for s in a): + return None + b = args[1] if len(args) > 1 else None + axis = kwargs.get("axis", None) + axes = kwargs.get("axes", None) + offset = kwargs.get("offset", 0) + + # --- concat --- + if func_name == "concat": + shapes = args[0] + if axis 
is None and len(args) > 1: + axis = args[1] + + # Coerce axis to int if tuple single-element + axis = 0 if axis is None else axis + # normalize negative axis + axis = axis + len(shapes[0]) if axis < 0 else axis + concat_dim = builtins.sum([s[axis] for s in shapes]) + return tuple(s if i != axis else concat_dim for i, s in enumerate(shapes[0])) + + # --- diagonal --- + elif func_name == "diagonal": + axis1 = len(a) - 2 + axis2 = len(a) - 1 + new_shape = [d for i, d in enumerate(a) if i not in (axis1, axis2)] + d1, d2 = a[axis1], a[axis2] + diag_len = builtins.max(0, min(d1, d2) - abs(offset)) + new_shape.append(diag_len) + return tuple(new_shape) + + # --- expand_dims --- + elif func_name == "expand_dims": + # positional axis may be second positional argument + if axis is None and len(args) > 1: + axis = args[1] + if axis is None: + axis = 0 + axis = [axis] if isinstance(axis, int) else axis + new_shape = list(a) + for ax in sorted(axis): + ax = ax if ax >= 0 else len(new_shape) + ax + 1 + new_shape.insert(ax, 1) + return tuple(new_shape) + + # --- matmul --- + elif func_name == "matmul": + if b is None: + return None + x1_is_vector = False + x2_is_vector = False + if len(a) == 1: + a = (1,) + a # (N,) -> (1, N) + x1_is_vector = True + if len(b) == 1: + b += (1,) # (M,) -> (M, 1) + x2_is_vector = True + batch = broadcast_shapes(a[:-2], b[:-2]) + shape = batch + if not x1_is_vector: + shape += a[-2] + if not x2_is_vector: + shape += b[-1] + return shape + + # --- matrix_transpose --- + elif func_name == "matrix_transpose": + if len(a) < 2: + return a + return a[:-2] + (a[-1], a[-2]) + + # --- outer --- + elif func_name == "outer": + if b is None: + return None + return a + b + + # --- permute_dims --- + elif func_name == "permute_dims": + if axes is None and len(args) > 1: + axes = args[1] + if axes is None: + axes = tuple(reversed(range(len(a)))) + return tuple(a[i] for i in axes) + + # --- squeeze --- + elif func_name == "squeeze": + if axis is None and len(args) > 
1: + axis = args[1] + if axis is None: + return tuple(d for d in a if d != 1) + if isinstance(axis, int): + axis = (axis,) + axis = tuple(ax if ax >= 0 else len(a) + ax for ax in axis) + return tuple(d for i, d in enumerate(a) if i not in axis or d != 1) + + # --- stack --- + elif func_name == "stack": + # detect axis as last positional if candidate + elems = args[0] + if axis is None and len(args) > 1: + axis = args[1] + if axis is None: + axis = 0 + return elems[0][:axis] + (len(elems),) + elems[0][axis:] + + # --- tensordot --- + elif func_name == "tensordot": + if axes is None and len(args) > 2: + axes = args[2] + if axis is None: + axes = 2 + if b is None: + return None + if isinstance(axes, int): + a_rest = a[:-axes] + b_rest = b[axes:] + else: + a_axes, b_axes = axes + a_rest = tuple(d for i, d in enumerate(a) if i not in a_axes) + b_rest = tuple(d for i, d in enumerate(b) if i not in b_axes) + return a_rest + b_rest + + # --- transpose --- + elif func_name == ("transpose", "T", "mT"): + return a[:-2] + (a[-1], a[-2]) + + # --- vecdot --- + elif func_name == "vecdot": + if axis is None and len(args) > 2: + axis = args[2] + if axis is None: + axis = -1 + if b is None: + return None + a_axis = axis + len(a) + b_axis = axis + len(b) + a_rem = tuple(d for i, d in enumerate(a) if i != a_axis) + b_rem = tuple(d for i, d in enumerate(b) if i != b_axis) + return broadcast_shapes(a_rem, b_rem) + else: + return None + + def reduce_shape(shape, axis, keepdims): """Reduce shape along given axis or axes (collapse dimensions).""" if shape is None: @@ -133,8 +282,18 @@ def visit_Call(self, node): # noqa : C901 else: kwargs[kw.arg] = self._lookup_value(kw.value) + # ------- handle linear algebra --------------- + target = None if func_name in lin_alg_funcs: - return None # need to implement shape handling for these funcs + target = func_name + if attr_name in lin_alg_funcs: + target = attr_name + if target is not None: + args = [self.visit(arg) for arg in node.args] + # If 
it's a method call, prepend the object shape + if obj_shape is not None and attr_name == target: + args.insert(0, obj_shape) + return linalg_shape(target, args, kwargs) # ------- handle constructors --------------- if func_name in constructors or attr_name == "reshape": @@ -241,7 +400,10 @@ def visit_Compare(self, node): return elementwise(*shapes) def visit_Constant(self, node): - return () + return () if not hasattr(node.value, "shape") else node.value.shape + + def visit_Tuple(self, node): + return tuple(self.visit(arg) for arg in node.elts) def visit_BinOp(self, node): left = self.visit(node.left) diff --git a/tests/ndarray/test_lazyexpr.py b/tests/ndarray/test_lazyexpr.py index 1e7a8dc6..a006e15a 100644 --- a/tests/ndarray/test_lazyexpr.py +++ b/tests/ndarray/test_lazyexpr.py @@ -1129,10 +1129,10 @@ def test_rebasing(array_fixture): assert expr.expression == "(o0 + o1 - o2 * o3)" expr = blosc2.lazyexpr("a1") - assert expr.expression == "o0" + assert expr.expression == "(o0)" expr = blosc2.lazyexpr("a1[:10]") - assert expr.expression == "o0.slice((slice(None, 10, None),))" + assert expr.expression == "(o0.slice((slice(None, 10, None),)))" # Test get_chunk method @@ -1595,7 +1595,96 @@ def test_not_numexpr(): np.testing.assert_array_almost_equal(d_blosc2, np.logaddexp(npa, npb) + npa) # TODO: Implement __add__ etc. 
for LazyUDF so this line works # d_blosc2 = blosc2.evaluate(f"logaddexp(a, b) + clip(a, 6, 12)") - arr = blosc2.lazyexpr("matmul(a,b) + a ") + arr = blosc2.lazyexpr("matmul(a, b)") assert isinstance(arr, blosc2.LazyExpr) - assert arr.shape is None # can't calculate shape for linalg funcs yet - np.testing.assert_array_almost_equal(arr[()], np.matmul(npa, npb) + a) + np.testing.assert_array_almost_equal(arr[()], np.matmul(npa, npb)) + + +def test_lazylinalg(): + """ + Test the shape parser for linear algebra funcs + """ + # --- define base shapes --- + shapes = { + "A": (3, 4), + "B": (4, 5), + "C": (2, 3, 4), + "D": (1, 5, 1), + "x": (10,), + "y": (10,), + } + s = shapes["x"] + x = blosc2.linspace(0, np.prod(s), shape=s) + s = shapes["y"] + y = blosc2.linspace(0, np.prod(s), shape=s) + s = shapes["A"] + A = blosc2.linspace(0, np.prod(s), shape=s) + s = shapes["B"] + B = blosc2.linspace(0, np.prod(s), shape=s) + s = shapes["C"] + C = blosc2.linspace(0, np.prod(s), shape=s) + s = shapes["D"] + D = blosc2.linspace(0, np.prod(s), shape=s) + + npx = x[()] + npy = y[()] + npA = A[()] + + # --- concat --- + out = blosc2.lazyexpr("concat((x, y), axis=0)") + assert out.shape == np.concat((npx, npy), axis=0).shape + + # --- diagonal --- + out = blosc2.lazyexpr("diagonal(A)") + assert out.shape == np.diagonal(npA).shape + + # --- expand_dims --- + out = blosc2.lazyexpr("expand_dims(x, axis=0)") + assert out.shape == (1,) + shapes["x"] + + # --- matmul --- + out = blosc2.lazyexpr("matmul(A, B)") + assert out.shape == (shapes["A"][0], shapes["B"][1]) + + # --- matrix_transpose --- + out = blosc2.lazyexpr("matrix_transpose(A)") + assert out.shape == (shapes["A"][1], shapes["A"][0]) + + # --- outer --- + out = blosc2.lazyexpr("outer(x, y)") + assert out.shape == shapes["x"] + shapes["y"] + + # --- permute_dims --- + out = blosc2.lazyexpr("permute_dims(C, axes=(2,0,1))") + assert out.shape == (shapes["C"][2], shapes["C"][0], shapes["C"][1]) + + # --- squeeze --- + out = 
blosc2.lazyexpr("squeeze(D)") + assert out.shape == (5,) + out = blosc2.lazyexpr("D.squeeze()") + assert out.shape == (5,) + + # --- stack --- + out = blosc2.lazyexpr("stack((x, y), axis=0)") + assert out.shape == (2,) + shapes["x"] + + # --- tensordot --- + out = blosc2.lazyexpr("tensordot(A, B, axes=1)") + assert out.shape[0] == shapes["A"][0] + assert out.shape[-1] == shapes["B"][-1] + + # --- vecdot --- + out = blosc2.lazyexpr("vecdot(x, y)") + assert out.shape == np.vecdot(x[()], y[()]).shape + + # batched matmul + shapes = { + "A": (1, 3, 4), + "B": (3, 4, 5), + } + s = shapes["A"] + A = blosc2.linspace(0, np.prod(s), shape=s) + s = shapes["B"] + B = blosc2.linspace(0, np.prod(s), shape=s) + out = blosc2.lazyexpr("matmul(A, B)") + assert out.shape == (3, 3, 5) From 38c6bd3269a5bea0c737ec9b918f4606a22dd4aa Mon Sep 17 00:00:00 2001 From: lshaw8317 Date: Mon, 6 Oct 2025 16:14:12 +0200 Subject: [PATCH 06/10] Add Francesc's suggestions, improve tests --- src/blosc2/lazyexpr.py | 17 +++++---- src/blosc2/shape_utils.py | 23 ++++++++--- tests/ndarray/test_lazyexpr.py | 70 ++++++++++++++++++++++++++-------- 3 files changed, 83 insertions(+), 27 deletions(-) diff --git a/src/blosc2/lazyexpr.py b/src/blosc2/lazyexpr.py index bc5adcac..493b87ce 100644 --- a/src/blosc2/lazyexpr.py +++ b/src/blosc2/lazyexpr.py @@ -53,7 +53,7 @@ from .shape_utils import constructors, infer_shape, lin_alg_funcs, reducers -lin_alg_funcs += ("clip", "logaddexp") +not_numexpr_funcs = lin_alg_funcs + ("clip", "logaddexp") if not blosc2.IS_WASM: import numexpr @@ -645,8 +645,11 @@ def validate_expr(expr: str) -> None: skip_quotes = re.sub(r"(\'[^\']*\')", "", no_whitespace) # Check for forbidden patterns - if _blacklist_re.search(skip_quotes) is not None: - raise ValueError(f"'{expr}' is not a valid expression.") + forbiddens = _blacklist_re.search(skip_quotes) + if forbiddens is not None: + i = forbiddens.span()[0] + if expr[i : i + 2] != ".T" and expr[i : i + 3] != ".mT": # allow tranpose 
methods + raise ValueError(f"'{expr}' is not a valid expression.") # Check for invalid characters not covered by the tokenizer invalid_chars = re.compile(r"[^\w\s+\-*/%()[].,=<>!&|~^]") @@ -706,7 +709,7 @@ def visit_Call(self, node): try: shape = infer_shape(full_expr, shapes) except Exception as e: - print(f"⚠️ Shape inference failed for {full_expr}: {e}") + print(f"Shape inference failed for {full_expr}: {e}") shape = () # Determine dtype @@ -2920,7 +2923,7 @@ def find_args(expr): return value, expression[idx:idx2] def _compute_expr(self, item, kwargs): - if any(method in self.expression for method in reducers + lin_alg_funcs): + if any(method in self.expression for method in reducers + not_numexpr_funcs): # We have reductions in the expression (probably coming from a string lazyexpr) # Also includes slice _globals = get_expr_globals(self.expression) @@ -3441,8 +3444,8 @@ def _numpy_eval_expr(expression, operands, prefer_blosc=False): if "contains" in expression: _out = ne_evaluate(expression, local_dict=ops) else: - # Create a globals dict with the functions of blosc2 preferentially - # (and numpy if can't find blosc2) + # Create a globals dict with blosc2 version of functions preferentially + # (default to numpy func if not implemented in blosc2) if prefer_blosc: _globals = get_expr_globals(expression) _globals |= dtype_symbols diff --git a/src/blosc2/shape_utils.py b/src/blosc2/shape_utils.py index f0756366..3b7a4412 100644 --- a/src/blosc2/shape_utils.py +++ b/src/blosc2/shape_utils.py @@ -104,9 +104,9 @@ def linalg_shape(func_name, args, kwargs): # noqa: C901 batch = broadcast_shapes(a[:-2], b[:-2]) shape = batch if not x1_is_vector: - shape += a[-2] + shape += (a[-2],) if not x2_is_vector: - shape += b[-1] + shape += (b[-1],) return shape # --- matrix_transpose --- @@ -154,7 +154,7 @@ def linalg_shape(func_name, args, kwargs): # noqa: C901 elif func_name == "tensordot": if axes is None and len(args) > 2: axes = args[2] - if axis is None: + if axes is None: 
axes = 2 if b is None: return None @@ -168,7 +168,7 @@ def linalg_shape(func_name, args, kwargs): # noqa: C901 return a_rest + b_rest # --- transpose --- - elif func_name == ("transpose", "T", "mT"): + elif func_name in ("transpose", "T", "mT"): return a[:-2] + (a[-1], a[-2]) # --- vecdot --- @@ -261,9 +261,22 @@ def visit_Name(self, node): else: # passed a scalar value return () + def visit_Attribute(self, node): + obj_shape = self.visit(node.value) + attr = node.attr + if attr == "reshape": + if node.args: + shape_arg = node.args[-1] + if isinstance(shape_arg, ast.Tuple): + return tuple(self._lookup_value(e) for e in shape_arg.elts) + return () + elif attr in ("T", "mT"): + return linalg_shape(attr, (obj_shape,), {}) + return None + def visit_Call(self, node): # noqa : C901 func_name = getattr(node.func, "id", None) - attr_name = getattr(node.func, "attr", None) + attr_name = getattr(node.func, "attr", None) # handle methods called on funcs # --- Recursive method-chain support --- obj_shape = None diff --git a/tests/ndarray/test_lazyexpr.py b/tests/ndarray/test_lazyexpr.py index a006e15a..779d664c 100644 --- a/tests/ndarray/test_lazyexpr.py +++ b/tests/ndarray/test_lazyexpr.py @@ -1629,62 +1629,102 @@ def test_lazylinalg(): npx = x[()] npy = y[()] npA = A[()] + npB = B[()] + npC = C[()] + npD = D[()] # --- concat --- out = blosc2.lazyexpr("concat((x, y), axis=0)") - assert out.shape == np.concat((npx, npy), axis=0).shape + npres = np.concatenate((npx, npy), axis=0) + assert out.shape == npres.shape + np.testing.assert_array_almost_equal(out[()], npres) # --- diagonal --- out = blosc2.lazyexpr("diagonal(A)") - assert out.shape == np.diagonal(npA).shape + npres = np.diagonal(npA) + assert out.shape == npres.shape + np.testing.assert_array_almost_equal(out[()], npres) # --- expand_dims --- out = blosc2.lazyexpr("expand_dims(x, axis=0)") - assert out.shape == (1,) + shapes["x"] + npres = np.expand_dims(npx, axis=0) + assert out.shape == npres.shape + 
np.testing.assert_array_almost_equal(out[()], npres) # --- matmul --- out = blosc2.lazyexpr("matmul(A, B)") - assert out.shape == (shapes["A"][0], shapes["B"][1]) + npres = np.matmul(npA, npB) + assert out.shape == npres.shape + np.testing.assert_array_almost_equal(out[()], npres) # --- matrix_transpose --- out = blosc2.lazyexpr("matrix_transpose(A)") - assert out.shape == (shapes["A"][1], shapes["A"][0]) + npres = np.matrix_transpose(npA) + assert out.shape == npres.shape + np.testing.assert_array_almost_equal(out[()], npres) + out = blosc2.lazyexpr("C.mT") + npres = C.mT + assert out.shape == npres.shape + np.testing.assert_array_almost_equal(out[()], npres) + out = blosc2.lazyexpr("A.T") + npres = npA.T + assert out.shape == npres.shape + np.testing.assert_array_almost_equal(out[()], npres) # --- outer --- out = blosc2.lazyexpr("outer(x, y)") - assert out.shape == shapes["x"] + shapes["y"] + npres = np.outer(npx, npy) + assert out.shape == npres.shape + np.testing.assert_array_almost_equal(out[()], npres) # --- permute_dims --- out = blosc2.lazyexpr("permute_dims(C, axes=(2,0,1))") - assert out.shape == (shapes["C"][2], shapes["C"][0], shapes["C"][1]) + npres = np.transpose(npC, axes=(2, 0, 1)) + assert out.shape == npres.shape + np.testing.assert_array_almost_equal(out[()], npres) # --- squeeze --- out = blosc2.lazyexpr("squeeze(D)") - assert out.shape == (5,) + npres = np.squeeze(npD) + assert out.shape == npres.shape + np.testing.assert_array_almost_equal(out[()], npres) + out = blosc2.lazyexpr("D.squeeze()") - assert out.shape == (5,) + npres = np.squeeze(npD) + assert out.shape == npres.shape + np.testing.assert_array_almost_equal(out[()], npres) # --- stack --- out = blosc2.lazyexpr("stack((x, y), axis=0)") - assert out.shape == (2,) + shapes["x"] + npres = np.stack((npx, npy), axis=0) + assert out.shape == npres.shape + np.testing.assert_array_almost_equal(out[()], npres) # --- tensordot --- out = blosc2.lazyexpr("tensordot(A, B, axes=1)") - assert 
out.shape[0] == shapes["A"][0] - assert out.shape[-1] == shapes["B"][-1] + npres = np.tensordot(npA, npB, axes=1) + assert out.shape == npres.shape + np.testing.assert_array_almost_equal(out[()], npres) # --- vecdot --- out = blosc2.lazyexpr("vecdot(x, y)") - assert out.shape == np.vecdot(x[()], y[()]).shape + npres = np.vecdot(npx, npy) + assert out.shape == npres.shape + np.testing.assert_array_almost_equal(out[()], npres) - # batched matmul + # --- batched matmul --- shapes = { "A": (1, 3, 4), "B": (3, 4, 5), } s = shapes["A"] A = blosc2.linspace(0, np.prod(s), shape=s) + npA = A[()] # actual numpy array s = shapes["B"] B = blosc2.linspace(0, np.prod(s), shape=s) + npB = B[()] # actual numpy array + out = blosc2.lazyexpr("matmul(A, B)") - assert out.shape == (3, 3, 5) + npres = np.matmul(npA, npB) + assert out.shape == npres.shape + np.testing.assert_array_almost_equal(out[()], npres) From 78cb8640d3a0203cf73daee269a19bdea490f0dc Mon Sep 17 00:00:00 2001 From: lshaw8317 Date: Tue, 7 Oct 2025 11:37:15 +0200 Subject: [PATCH 07/10] Fixed compatibility with numpy 1.26 --- src/blosc2/lazyexpr.py | 224 ++++++++++++++++++++++++--------- src/blosc2/linalg.py | 8 +- src/blosc2/ndarray.py | 20 ++- src/blosc2/shape_utils.py | 64 +++++----- tests/ndarray/test_lazyexpr.py | 6 +- tests/ndarray/test_lazyudf.py | 9 ++ tests/ndarray/test_linalg.py | 9 +- 7 files changed, 236 insertions(+), 104 deletions(-) diff --git a/src/blosc2/lazyexpr.py b/src/blosc2/lazyexpr.py index 493b87ce..f832c77c 100644 --- a/src/blosc2/lazyexpr.py +++ b/src/blosc2/lazyexpr.py @@ -41,6 +41,7 @@ from blosc2 import compute_chunks_blocks from blosc2.info import InfoReporter from blosc2.ndarray import ( + NUMPY_GE_2_0, _check_allowed_dtypes, get_chunks_idx, get_intersecting_chunks, @@ -53,11 +54,36 @@ from .shape_utils import constructors, infer_shape, lin_alg_funcs, reducers -not_numexpr_funcs = lin_alg_funcs + ("clip", "logaddexp") - if not blosc2.IS_WASM: import numexpr +global safe_blosc2_globals 
+safe_blosc2_globals = {} +global safe_numpy_globals +# Use numpy eval when running in WebAssembly +safe_numpy_globals = {"np": np} +# Add all first-level numpy functions +safe_numpy_globals.update( + {name: getattr(np, name) for name in dir(np) if callable(getattr(np, name)) and not name.startswith("_")} +) + +if not NUMPY_GE_2_0: # handle non-array-api compliance + safe_numpy_globals["acos"] = np.arccos + safe_numpy_globals["acosh"] = np.arccosh + safe_numpy_globals["asin"] = np.arcsin + safe_numpy_globals["asinh"] = np.arcsinh + safe_numpy_globals["atan"] = np.arctan + safe_numpy_globals["atanh"] = np.arctanh + safe_numpy_globals["atan2"] = np.arctan2 + safe_numpy_globals["permute_dims"] = np.transpose + safe_numpy_globals["pow"] = np.power + safe_numpy_globals["bitwise_left_shift"] = np.left_shift + safe_numpy_globals["bitwise_right_shift"] = np.right_shift + safe_numpy_globals["bitwise_invert"] = np.bitwise_not + safe_numpy_globals["concat"] = np.concatenate + safe_numpy_globals["matrix_transpose"] = np.transpose + safe_numpy_globals["vecdot"] = blosc2.ndarray.npvecdot + def ne_evaluate(expression, local_dict=None, **kwargs): """Safely evaluate expressions using numexpr when possible, falling back to numpy.""" @@ -76,22 +102,24 @@ def ne_evaluate(expression, local_dict=None, **kwargs): ) } if blosc2.IS_WASM: - # Use numpy eval when running in WebAssembly - safe_globals = {"np": np} - # Add all first-level numpy functions - safe_globals.update( - { - name: getattr(np, name) - for name in dir(np) - if callable(getattr(np, name)) and not name.startswith("_") - } - ) + global safe_numpy_globals + if "out" in kwargs: + out = kwargs.pop("out") + out[:] = eval(expression, safe_numpy_globals, local_dict) + return out + return eval(expression, safe_numpy_globals, local_dict) + try: + return numexpr.evaluate(expression, local_dict=local_dict, **kwargs) + except ValueError as e: + raise e # unsafe expression + except Exception: # non_numexpr functions present + global 
safe_blosc2_globals + res = eval(expression, safe_blosc2_globals, local_dict) if "out" in kwargs: out = kwargs.pop("out") - out[:] = eval(expression, safe_globals, local_dict) + out[:] = res[()] if isinstance(res, blosc2.LazyArray) else res return out - return eval(expression, safe_globals, local_dict) - return numexpr.evaluate(expression, local_dict=local_dict, **kwargs) + return res[()] if isinstance(res, blosc2.LazyArray) else res # Define empty ndindex tuple for function defaults @@ -131,56 +159,116 @@ def ne_evaluate(expression, local_dict=None, **kwargs): "V": np.bytes_, } -functions = [ - "sin", - "cos", - "tan", - "sqrt", - "sinh", - "cosh", - "tanh", - "arcsin", +blosc2_funcs = [ + "abs", + "acos", + "acosh", + "add", + "all", + "any", + "arange", "arccos", + "arccosh", + "arcsin", + "arcsinh", "arctan", "arctan2", - "arcsinh", - "arccosh", "arctanh", - "exp", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_invert", + "bitwise_left_shift", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "broadcast_to", + "ceil", + "clip", + "concat", + "concatenate", + "copy", + "copysign", + "count_nonzero", + "divide", + "empty", + "empty_like", + "equal", + "expand_dims", "expm1", + "eye", + "floor", + "floor_divide", + "frombuffer", + "fromiter", + "full", + "full_like", + "greater", + "greater_equal", + "hypot", + "isfinite", + "isinf", + "isnan", + "less_equal", + "less_than", + "linspace", "log", - "log10", "log1p", "log2", - "conj", - "real", - "imag", - "contains", - "abs", - "sum", - "prod", - "mean", - "std", - "var", - "min", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "matmul", + "matrix_transpose", "max", - "any", - "all", - "pow" if np.__version__.startswith("2.") else "power", - "where", - "isnan", - "isfinite", - "isinf", - "nextafter", - "copysign", - "hypot", "maximum", + "mean", + "meshgrid", + "min", "minimum", - "floor", - "ceil", - "trunc", - "signbit", + 
"multiply", + "nans", + "ndarray_from_cframe", + "negative", + "nextafter", + "not_equal", + "ones", + "ones_like", + "permute_dims", + "positive", + "pow", + "prod", + "real", + "reciprocal", + "remainder", + "reshape", "round", + "sign", + "signbit", + "sort", + "square", + "squeeze", + "stack", + "sum", + "subtract", + "take", + "take_along_axis", + "tan", + "tanh", + "tensordot", + "transpose", + "trunc", + "var", + "vecdot", + "where", + "zeros", + "zeros_like", ] # Gather all callable functions in numpy @@ -192,10 +280,8 @@ def ne_evaluate(expression, local_dict=None, **kwargs): numpy_ufuncs = {name for name, member in inspect.getmembers(np, lambda x: isinstance(x, np.ufunc))} # Add these functions to the list of available functions # (will be evaluated via the array interface) -additional_funcs = sorted((numpy_funcs | numpy_ufuncs) - set(functions)) -functions += additional_funcs - -functions += constructors +additional_funcs = sorted((numpy_funcs | numpy_ufuncs) - set(blosc2_funcs)) +functions = blosc2_funcs + additional_funcs relational_ops = ["==", "!=", "<", "<=", ">", ">="] logical_ops = ["&", "|", "^", "~"] @@ -204,8 +290,7 @@ def ne_evaluate(expression, local_dict=None, **kwargs): def get_expr_globals(expression): """Build a dictionary of functions needed for evaluating the expression.""" - _globals = {} - + _globals = {"np": np, "blosc2": blosc2} # Only check for functions that actually appear in the expression # This avoids many unnecessary string searches for func in functions: @@ -2922,8 +3007,23 @@ def find_args(expr): return value, expression[idx:idx2] - def _compute_expr(self, item, kwargs): - if any(method in self.expression for method in reducers + not_numexpr_funcs): + def _compute_expr(self, item, kwargs): # noqa : C901 + # ne_evaluate will need safe_blosc2_globals for some functions (e.g. 
clip, logaddexp) + # that are implemenetd in python-blosc2 not in numexpr + global safe_blosc2_globals + if len(safe_blosc2_globals) == 0: + # First eval call, fill blosc2_safe_globals for ne_evaluate + safe_blosc2_globals = {"blosc2": blosc2} + # Add all first-level blosc2 functions + safe_blosc2_globals.update( + { + name: getattr(blosc2, name) + for name in dir(blosc2) + if callable(getattr(blosc2, name)) and not name.startswith("_") + } + ) + + if any(method in self.expression for method in reducers + lin_alg_funcs): # We have reductions in the expression (probably coming from a string lazyexpr) # Also includes slice _globals = get_expr_globals(self.expression) @@ -3450,7 +3550,7 @@ def _numpy_eval_expr(expression, operands, prefer_blosc=False): _globals = get_expr_globals(expression) _globals |= dtype_symbols else: - _globals = {f: getattr(np, f) for f in functions if f not in ("contains", "pow")} + _globals = safe_numpy_globals try: _out = eval(expression, _globals, ops) except RuntimeWarning: diff --git a/src/blosc2/linalg.py b/src/blosc2/linalg.py index 65cc2fb9..5b56a5c2 100644 --- a/src/blosc2/linalg.py +++ b/src/blosc2/linalg.py @@ -9,7 +9,7 @@ import numpy as np import blosc2 -from blosc2.ndarray import get_intersecting_chunks, slice_to_chunktuple +from blosc2.ndarray import get_intersecting_chunks, npvecdot, slice_to_chunktuple if TYPE_CHECKING: from collections.abc import Sequence @@ -340,7 +340,7 @@ def vecdot(x1: blosc2.NDArray, x2: blosc2.NDArray, axis: int = -1, **kwargs) -> fast_path = kwargs.pop("fast_path", None) # for testing purposes # Added this to pass array-api tests (which use internal getitem to check results) if isinstance(x1, np.ndarray) and isinstance(x2, np.ndarray): - return np.vecdot(x1, x2, axis=axis) + return npvecdot(x1, x2, axis=axis) x1, x2 = blosc2.asarray(x1), blosc2.asarray(x2) @@ -399,7 +399,7 @@ def vecdot(x1: blosc2.NDArray, x2: blosc2.NDArray, axis: int = -1, **kwargs) -> if fast_path: # just load everything, also 
handles case of 0 in shapes bx1 = x1[a_selection] bx2 = x2[b_selection] - result[res_chunk] += np.vecdot(bx1, bx2, axis=axis) # handles conjugation of bx1 + result[res_chunk] += npvecdot(bx1, bx2, axis=axis) # handles conjugation of bx1 else: # operands too big, have to go chunk-by-chunk for ochunk in range(0, a_shape_red, a_chunks_red): op_chunk = (slice(ochunk, builtins.min(ochunk + a_chunks_red, x1.shape[a_axes]), 1),) @@ -407,7 +407,7 @@ def vecdot(x1: blosc2.NDArray, x2: blosc2.NDArray, axis: int = -1, **kwargs) -> b_selection = b_selection[:b_axes] + op_chunk + b_selection[b_axes + 1 :] bx1 = x1[a_selection] bx2 = x2[b_selection] - res = np.vecdot(bx1, bx2, axis=axis) # handles conjugation of bx1 + res = npvecdot(bx1, bx2, axis=axis) # handles conjugation of bx1 result[res_chunk] += res return result diff --git a/src/blosc2/ndarray.py b/src/blosc2/ndarray.py index 4d47dfb4..44ce8012 100644 --- a/src/blosc2/ndarray.py +++ b/src/blosc2/ndarray.py @@ -41,11 +41,16 @@ nplshift = np.bitwise_left_shift nprshift = np.bitwise_right_shift npbinvert = np.bitwise_invert + npvecdot = np.vecdot else: # not array-api compliant nplshift = np.left_shift nprshift = np.right_shift npbinvert = np.bitwise_not + def npvecdot(a, b, axis=-1): + return np.einsum("...i,...i->...", np.moveaxis(np.conj(a), axis, -1), np.moveaxis(b, axis, -1)) + + # These functions in ufunc_map in ufunc_map_1param are implemented in numexpr and so we call # those instead (since numexpr uses multithreading it is faster) ufunc_map = { @@ -2932,6 +2937,7 @@ def clip( x: blosc2.Array, min: int | float | blosc2.Array | None = None, max: int | float | blosc2.Array | None = None, + **kwargs: Any, ) -> NDArray: """ Clamps each element x_i of the input array x to the range [min, max]. @@ -2949,6 +2955,9 @@ def clip( Upper-bound of the range to which to clamp. If None, no upper bound must be applied. Default: None. 
+ kwargs: Any + kwargs accepted by the :func:`empty` constructor + Returns ------- out: NDArray @@ -2960,10 +2969,10 @@ def chunkwise_clip(inputs, output, offset): x, min, max = inputs output[:] = np.clip(x, min, max) - return blosc2.lazyudf(chunkwise_clip, (x, min, max), dtype=x.dtype, shape=x.shape) + return blosc2.lazyudf(chunkwise_clip, (x, min, max), dtype=x.dtype, shape=x.shape, **kwargs) -def logaddexp(x1: int | float | blosc2.Array, x2: int | float | blosc2.Array) -> NDArray: +def logaddexp(x1: int | float | blosc2.Array, x2: int | float | blosc2.Array, **kwargs: Any) -> NDArray: """ Calculates the logarithm of the sum of exponentiations log(exp(x1) + exp(x2)) for each element x1_i of the input array x1 with the respective element x2_i of the @@ -2974,10 +2983,13 @@ def logaddexp(x1: int | float | blosc2.Array, x2: int | float | blosc2.Array) -> x1: blosc2.Array First input array. May have any real-valued floating-point data type. - x2:blosc2.Array + x2: blosc2.Array Second input array. Must be compatible with x1. May have any real-valued floating-point data type. 
+ kwargs: Any + kwargs accepted by the :func:`empty` constructor + Returns ------- out: NDArray @@ -2995,7 +3007,7 @@ def chunkwise_logaddexp(inputs, output, offset): if np.issubdtype(dtype, np.integer): dtype = blosc2.float32 - return blosc2.lazyudf(chunkwise_logaddexp, (x1, x2), dtype=dtype, shape=x1.shape) + return blosc2.lazyudf(chunkwise_logaddexp, (x1, x2), dtype=dtype, shape=x1.shape, **kwargs) # implemented in python-blosc2 diff --git a/src/blosc2/shape_utils.py b/src/blosc2/shape_utils.py index 3b7a4412..a1017e57 100644 --- a/src/blosc2/shape_utils.py +++ b/src/blosc2/shape_utils.py @@ -18,6 +18,8 @@ "vecdot", "T", "mT", + "take", + "take_along_axis", ) reducers = ("sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "slice", "count_nonzero") @@ -275,14 +277,32 @@ def visit_Attribute(self, node): return None def visit_Call(self, node): # noqa : C901 + # Extract full function name (support np.func, blosc2.func) func_name = getattr(node.func, "id", None) - attr_name = getattr(node.func, "attr", None) # handle methods called on funcs + attr_name = getattr(node.func, "attr", None) + module_name = getattr(getattr(node.func, "value", None), "id", None) + + # Handle namespaced calls like np.func or blosc2.func + if module_name in ("np", "blosc2"): + qualified_name = f"{module_name}.{attr_name}" + else: + qualified_name = attr_name or func_name + + base_name = qualified_name.split(".")[-1] # --- Recursive method-chain support --- obj_shape = None - if isinstance(node.func, ast.Attribute): + if isinstance(node.func, ast.Attribute) and module_name not in ( + "np", + "blosc2", + ): # check if genuine method and not module func obj_shape = self.visit(node.func.value) + args = [self.visit(arg) for arg in node.args] + # If it's a method call, prepend the object shape + if obj_shape is not None and attr_name == base_name: + args.insert(0, obj_shape) + # --- Parse keyword args --- kwargs = {} for kw in node.keywords: @@ -296,27 +316,18 @@ def visit_Call(self, 
node): # noqa : C901 kwargs[kw.arg] = self._lookup_value(kw.value) # ------- handle linear algebra --------------- - target = None - if func_name in lin_alg_funcs: - target = func_name - if attr_name in lin_alg_funcs: - target = attr_name - if target is not None: - args = [self.visit(arg) for arg in node.args] - # If it's a method call, prepend the object shape - if obj_shape is not None and attr_name == target: - args.insert(0, obj_shape) - return linalg_shape(target, args, kwargs) + if base_name in lin_alg_funcs: + return linalg_shape(base_name, args, kwargs) # ------- handle constructors --------------- - if func_name in constructors or attr_name == "reshape": + if base_name in constructors: # shape kwarg directly provided if "shape" in kwargs: val = kwargs["shape"] return val if isinstance(val, tuple) else (val,) # ---- array constructors like zeros, ones, full, etc. ---- - elif func_name in ( + elif base_name in ( "zeros", "ones", "empty", @@ -338,7 +349,7 @@ def visit_Call(self, node): # noqa : C901 return shape # ---- arange ---- - elif func_name == "arange": + elif base_name == "arange": start = self._lookup_value(node.args[0]) if node.args else 0 stop = self._lookup_value(node.args[1]) if len(node.args) > 1 else None step = self._lookup_value(node.args[2]) if len(node.args) > 2 else 1 @@ -358,7 +369,7 @@ def visit_Call(self, node): # noqa : C901 return (max(NUM, 0),) # ---- linspace ---- - elif func_name == "linspace": + elif base_name == "linspace": num = self._lookup_value(node.args[2]) if len(node.args) > 2 else kwargs.get("num") shape = self._lookup_value(node.args[5]) if len(node.args) > 5 else kwargs.get("shape") if shape is not None: @@ -367,15 +378,16 @@ def visit_Call(self, node): # noqa : C901 return (num,) raise ValueError("linspace requires either shape or num argument") - elif func_name == "frombuffer" or func_name == "fromiter": + elif base_name == "frombuffer" or base_name == "fromiter": count = kwargs.get("count") return (count,) if count 
else () - elif func_name == "eye": + + elif base_name == "eye": N = self._lookup_value(node.args[0]) M = self._lookup_value(node.args[1]) if len(node.args) > 1 else kwargs.get("M") return (N, N) if M is None else (N, M) - elif func_name == "reshape" or attr_name == "reshape": + elif base_name == "reshape": if node.args: shape_arg = node.args[-1] if isinstance(shape_arg, ast.Tuple): @@ -396,16 +408,10 @@ def visit_Call(self, node): # noqa : C901 slices = [self._eval_slice(slice_arg)] return slice_shape(obj_shape, slices) - # --- Evaluate argument shapes normally --- - args = [self.visit(arg) for arg in node.args] - - if func_name in FUNCTIONS: - return FUNCTIONS[func_name](*args, **kwargs) - if attr_name in FUNCTIONS: - return FUNCTIONS[attr_name](obj_shape, **kwargs) + if base_name in FUNCTIONS: + return FUNCTIONS[base_name](*args, **kwargs) - shapes = [obj_shape] + args if obj_shape is not None else args - shapes = [s for s in shapes if s is not None] + shapes = [s for s in args if s is not None] return elementwise(*shapes) if shapes else () def visit_Compare(self, node): diff --git a/tests/ndarray/test_lazyexpr.py b/tests/ndarray/test_lazyexpr.py index 779d664c..b4f8edd2 100644 --- a/tests/ndarray/test_lazyexpr.py +++ b/tests/ndarray/test_lazyexpr.py @@ -13,7 +13,7 @@ import blosc2 from blosc2.lazyexpr import ne_evaluate -from blosc2.ndarray import get_chunks_idx +from blosc2.ndarray import get_chunks_idx, npvecdot NITEMS_SMALL = 1_000 NITEMS = 10_000 @@ -1659,7 +1659,7 @@ def test_lazylinalg(): # --- matrix_transpose --- out = blosc2.lazyexpr("matrix_transpose(A)") - npres = np.matrix_transpose(npA) + npres = np.matrix_transpose(npA) if np.__version__.startswith("2.") else npA.T assert out.shape == npres.shape np.testing.assert_array_almost_equal(out[()], npres) out = blosc2.lazyexpr("C.mT") @@ -1708,7 +1708,7 @@ def test_lazylinalg(): # --- vecdot --- out = blosc2.lazyexpr("vecdot(x, y)") - npres = np.vecdot(npx, npy) + npres = npvecdot(npx, npy) assert 
out.shape == npres.shape np.testing.assert_array_almost_equal(out[()], npres) diff --git a/tests/ndarray/test_lazyudf.py b/tests/ndarray/test_lazyudf.py index 71e0dd05..be53e175 100644 --- a/tests/ndarray/test_lazyudf.py +++ b/tests/ndarray/test_lazyudf.py @@ -423,6 +423,10 @@ def test_clip_logaddexp(shape, chunks, blocks, slices): # clip is not a ufunc so will return np.ndarray expr = np.clip(b, np.prod(shape) // 3, npb - 10) assert isinstance(expr, np.ndarray) + # test lazyexpr interface + expr = blosc2.lazyexpr("clip(b, np.prod(shape) // 3, npb - 10)") + res = expr.compute(item=slices) + np.testing.assert_allclose(res[...], npc[slices]) npc = np.logaddexp(npb, npa) expr = blosc2.logaddexp(b, a) @@ -432,3 +436,8 @@ def test_clip_logaddexp(shape, chunks, blocks, slices): # (i.e. doesn't return np.ndarray) expr = np.logaddexp(b, a) assert isinstance(expr, blosc2.LazyArray) + + # test lazyexpr interface + expr = blosc2.lazyexpr("logaddexp(a, b)") + res = expr.compute(item=slices) + np.testing.assert_allclose(res[...], npc[slices]) diff --git a/tests/ndarray/test_linalg.py b/tests/ndarray/test_linalg.py index 479eb8c8..fd6ecb0f 100644 --- a/tests/ndarray/test_linalg.py +++ b/tests/ndarray/test_linalg.py @@ -4,6 +4,7 @@ import pytest import blosc2 +from blosc2.ndarray import npvecdot @pytest.mark.parametrize( @@ -501,11 +502,15 @@ def test_outer(shape1, chunk1, block1, shape2, chunk2, block2, chunkres, dtype): np.int64, np.float32, np.float64, + np.complex128, ], ) def test_vecdot(shape1, chunk1, block1, shape2, chunk2, block2, chunkres, axis, dtype): # Create operands with requested dtype a_b2 = blosc2.arange(0, np.prod(shape1), shape=shape1, chunks=chunk1, blocks=block1, dtype=dtype) + if dtype == np.complex128: + a_b2 += 1j + a_b2 = a_b2.compute() a_np = a_b2[()] # decompress b_b2 = blosc2.arange(0, np.prod(shape2), shape=shape2, chunks=chunk2, blocks=block2, dtype=dtype) b_np = b_b2[()] # decompress @@ -513,7 +518,7 @@ def test_vecdot(shape1, chunk1, block1, 
shape2, chunk2, block2, chunkres, axis, # NumPy reference and Blosc2 comparison np_raised = None try: - res_np = np.vecdot(a_np, b_np, axis=axis) + res_np = npvecdot(a_np, b_np, axis=axis) except Exception as e: np_raised = type(e) @@ -523,7 +528,7 @@ def test_vecdot(shape1, chunk1, block1, shape2, chunk2, block2, chunkres, axis, blosc2.vecdot(a_b2, b_b2, axis=axis, chunks=chunkres) else: # Both should succeed - res_np = np.vecdot(a_np, b_np, axis=axis) + res_np = npvecdot(a_np, b_np, axis=axis) res_b2 = blosc2.vecdot(a_b2, b_b2, axis=axis, chunks=chunkres, fast_path=False) # test slow path res_b2_np = res_b2[...] From 580e3f51db6c5292cc7d394374430cea5a67e9dd Mon Sep 17 00:00:00 2001 From: lshaw8317 Date: Tue, 7 Oct 2025 13:16:59 +0200 Subject: [PATCH 08/10] Fix arctan2 issue --- src/blosc2/ndarray.py | 2 +- tests/ndarray/test_lazyexpr.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/blosc2/ndarray.py b/src/blosc2/ndarray.py index 44ce8012..e8f991e3 100644 --- a/src/blosc2/ndarray.py +++ b/src/blosc2/ndarray.py @@ -3076,7 +3076,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): if ufunc in ufunc_map: value = inputs[0] if inputs[1] is self else inputs[1] _check_allowed_dtypes(value) - return blosc2.LazyExpr(new_op=(value, ufunc_map[ufunc], self)) + return blosc2.LazyExpr(new_op=(inputs[0], ufunc_map[ufunc], inputs[1])) if ufunc in ufunc_map_1param: value = inputs[0] diff --git a/tests/ndarray/test_lazyexpr.py b/tests/ndarray/test_lazyexpr.py index b4f8edd2..15eae1d8 100644 --- a/tests/ndarray/test_lazyexpr.py +++ b/tests/ndarray/test_lazyexpr.py @@ -1550,11 +1550,10 @@ def test_complex_lazy_expression_multiplication(): theta_np = np.arctan2(Y_b2[:], X_b2[:]) expected = np.sin(R_np * 4 - time_factor * 2) * np.cos(theta_np * 6) - # TODO: for some reason, the result is negative, so we assert against -expected - np.testing.assert_allclose(result, -expected, rtol=1e-14, atol=1e-14) + np.testing.assert_allclose(result, 
expected, rtol=1e-14, atol=1e-14) # Also test getitem access - np.testing.assert_allclose(result_expr[:], -expected, rtol=1e-14, atol=1e-14) + np.testing.assert_allclose(result_expr[:], expected, rtol=1e-14, atol=1e-14) # Test checking that objects following the blosc2.Array protocol can be operated with From b461889996d39d40f1b824f97a2574766d401f1b Mon Sep 17 00:00:00 2001 From: lshaw8317 Date: Tue, 7 Oct 2025 18:51:01 +0200 Subject: [PATCH 09/10] Improve handling of known/unknown funcs --- src/blosc2/lazyexpr.py | 135 ++----------------------------- src/blosc2/shape_utils.py | 101 ++++++++++++++++++++--- tests/ndarray/test_reductions.py | 6 +- 3 files changed, 100 insertions(+), 142 deletions(-) diff --git a/src/blosc2/lazyexpr.py b/src/blosc2/lazyexpr.py index f832c77c..cf413f42 100644 --- a/src/blosc2/lazyexpr.py +++ b/src/blosc2/lazyexpr.py @@ -45,14 +45,13 @@ _check_allowed_dtypes, get_chunks_idx, get_intersecting_chunks, - is_inside_new_expr, local_ufunc_map, process_key, ufunc_map, ufunc_map_1param, ) -from .shape_utils import constructors, infer_shape, lin_alg_funcs, reducers +from .shape_utils import constructors, elementwise_funcs, infer_shape, lin_alg_attrs, lin_alg_funcs, reducers if not blosc2.IS_WASM: import numexpr @@ -158,119 +157,9 @@ def ne_evaluate(expression, local_dict=None, **kwargs): "S": np.str_, "V": np.bytes_, } - -blosc2_funcs = [ - "abs", - "acos", - "acosh", - "add", - "all", - "any", - "arange", - "arccos", - "arccosh", - "arcsin", - "arcsinh", - "arctan", - "arctan2", - "arctanh", - "asin", - "asinh", - "atan", - "atan2", - "atanh", - "bitwise_and", - "bitwise_invert", - "bitwise_left_shift", - "bitwise_or", - "bitwise_right_shift", - "bitwise_xor", - "broadcast_to", - "ceil", - "clip", - "concat", - "concatenate", - "copy", - "copysign", - "count_nonzero", - "divide", - "empty", - "empty_like", - "equal", - "expand_dims", - "expm1", - "eye", - "floor", - "floor_divide", - "frombuffer", - "fromiter", - "full", - "full_like", - 
"greater", - "greater_equal", - "hypot", - "isfinite", - "isinf", - "isnan", - "less_equal", - "less_than", - "linspace", - "log", - "log1p", - "log2", - "log10", - "logaddexp", - "logical_and", - "logical_not", - "logical_or", - "logical_xor", - "matmul", - "matrix_transpose", - "max", - "maximum", - "mean", - "meshgrid", - "min", - "minimum", - "multiply", - "nans", - "ndarray_from_cframe", - "negative", - "nextafter", - "not_equal", - "ones", - "ones_like", - "permute_dims", - "positive", - "pow", - "prod", - "real", - "reciprocal", - "remainder", - "reshape", - "round", - "sign", - "signbit", - "sort", - "square", - "squeeze", - "stack", - "sum", - "subtract", - "take", - "take_along_axis", - "tan", - "tanh", - "tensordot", - "transpose", - "trunc", - "var", - "vecdot", - "where", - "zeros", - "zeros_like", -] - +blosc2_funcs = constructors + lin_alg_funcs + elementwise_funcs + reducers +# functions that have to be evaluated before chunkwise lazyexpr machinery +eager_funcs = lin_alg_funcs + reducers + ["slice"] + lin_alg_attrs # Gather all callable functions in numpy numpy_funcs = { name @@ -751,12 +640,7 @@ def validate_expr(expr: str) -> None: def extract_and_replace_slices(expr, operands): """ - Replaces all var.slice(...).slice(...) chains in expr with oN temporary variables. - Infers shapes using infer_shape and creates placeholder arrays in new_ops. - - Returns: - new_expr: expression string with oN replacements - new_ops: dictionary mapping variable names (original and oN) to arrays + Return new expression and operands with op.slice(...) replaced by temporary operands. 
""" # Copy shapes and operands shapes = {k: () if not hasattr(v, "shape") else v.shape for k, v in operands.items()} @@ -1976,11 +1860,6 @@ def reduce_slices( # noqa: C901 if out is not None and reduced_shape != out.shape: raise ValueError("Provided output shape does not match the reduced shape.") - if is_inside_new_expr(): - # We already have the dtype and reduced_shape, so return immediately - # Use a blosc2 container, as it consumes less memory in general - return blosc2.zeros(reduced_shape, dtype=dtype) - # Choose the array with the largest shape as the reference for chunks # Note: we could have expr = blosc2.lazyexpr('numpy_array + 1') (i.e. no choice for chunks) blosc2_arrs = tuple(o for o in operands.values() if hasattr(o, "chunks")) @@ -3023,7 +2902,7 @@ def _compute_expr(self, item, kwargs): # noqa : C901 } ) - if any(method in self.expression for method in reducers + lin_alg_funcs): + if any(method in self.expression for method in eager_funcs): # We have reductions in the expression (probably coming from a string lazyexpr) # Also includes slice _globals = get_expr_globals(self.expression) @@ -3073,7 +2952,7 @@ def _compute_expr(self, item, kwargs): # noqa : C901 # Replace the constructor call by the new operand newexpr = newexpr.replace(constexpr, newop) - _globals = {func: getattr(blosc2, func) for func in functions if func in newexpr} + _globals = get_expr_globals(newexpr) lazy_expr = eval(newexpr, _globals, newops) if isinstance(lazy_expr, blosc2.NDArray): # Almost done (probably the expression is made of only constructors) diff --git a/src/blosc2/shape_utils.py b/src/blosc2/shape_utils.py index a1017e57..6b738cd7 100644 --- a/src/blosc2/shape_utils.py +++ b/src/blosc2/shape_utils.py @@ -1,9 +1,85 @@ import ast import builtins +import warnings from numpy import broadcast_shapes -lin_alg_funcs = ( +elementwise_funcs = [ + "abs", + "acos", + "acosh", + "add", + "arccos", + "arccosh", + "arcsin", + "arcsinh", + "arctan", + "arctan2", + "arctanh", + 
"asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_invert", + "bitwise_left_shift", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "broadcast_to", + "ceil", + "clip", + "copysign", + "divide", + "equal", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "hypot", + "isfinite", + "isinf", + "isnan", + "less_equal", + "less_than", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "maximum", + "minimum", + "multiply", + "negative", + "nextafter", + "not_equal", + "positive", + "pow", + "real", + "reciprocal", + "remainder", + "round", + "sign", + "signbit", + "square", + "sum", + "subtract", + "tan", + "tanh", + "trunc", + "var", + "where", + "zeros", + "zeros_like", +] + +lin_alg_funcs = [ "concat", "diagonal", "expand_dims", @@ -16,15 +92,13 @@ "tensordot", "transpose", "vecdot", - "T", - "mT", - "take", - "take_along_axis", -) -reducers = ("sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "slice", "count_nonzero") +] + +lin_alg_attrs = ["T", "mT"] +reducers = ["sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "count_nonzero"] # All the available constructors and reducers necessary for the (string) expression evaluator -constructors = ( +constructors = [ "arange", "linspace", "fromiter", @@ -38,9 +112,10 @@ "ones_like", "empty_like", "eye", -) + "nans", +] # Note that, as reshape is accepted as a method too, it should always come last in the list -constructors += ("reshape",) +constructors += ["reshape"] # --- Shape utilities --- @@ -336,6 +411,7 @@ def visit_Call(self, node): # noqa : C901 "zeros_like", "empty_like", "ones_like", + "nans", ): if node.args: shape_arg = node.args[0] @@ -412,6 +488,11 @@ def visit_Call(self, node): # noqa : C901 return FUNCTIONS[base_name](*args, **kwargs) shapes = [s for s in args if s is not None] + if base_name not in elementwise_funcs: + warnings.warn( + f"Function 
shape parser not implemented for {base_name}.", UserWarning, stacklevel=2 + ) + # default to elementwise but print warning that function not defined explicitly return elementwise(*shapes) if shapes else () def visit_Compare(self, node): diff --git a/tests/ndarray/test_reductions.py b/tests/ndarray/test_reductions.py index dbaa7997..ebc43e3f 100644 --- a/tests/ndarray/test_reductions.py +++ b/tests/ndarray/test_reductions.py @@ -459,12 +459,10 @@ def test_reduction_index(): assert arr.shape == newarr.shape a = blosc2.ones(shape=(0, 0)) - arr = blosc2.lazyexpr("sum(a, axis=(0, 1, 2))", {"a": a}) with pytest.raises(np.exceptions.AxisError): - newarr = arr.compute() - arr = blosc2.lazyexpr("sum(a, axis=(0, 0))", {"a": a}) + arr = blosc2.lazyexpr("sum(a, axis=(0, 1, 2))", {"a": a}) with pytest.raises(ValueError): - newarr = arr.compute() + arr = blosc2.lazyexpr("sum(a, axis=(0, 0))", {"a": a}) @pytest.mark.parametrize("idx", [0, 1, (0,), slice(1, 2), (slice(0, 1),), slice(0, 4), (0, 2)]) From ebdd8efc0e80e609ad4c1488954b5fe5cc455d9e Mon Sep 17 00:00:00 2001 From: lshaw8317 Date: Tue, 7 Oct 2025 19:48:18 +0200 Subject: [PATCH 10/10] Cleanup, add documentation --- ADD_LAZYFUNCS.md | 17 +++++++++++++++++ src/blosc2/lazyexpr.py | 18 +++++++----------- src/blosc2/shape_utils.py | 12 ++++++------ 3 files changed, 30 insertions(+), 17 deletions(-) create mode 100644 ADD_LAZYFUNCS.md diff --git a/ADD_LAZYFUNCS.md b/ADD_LAZYFUNCS.md new file mode 100644 index 00000000..4807d5dc --- /dev/null +++ b/ADD_LAZYFUNCS.md @@ -0,0 +1,17 @@ +# Adding (lazy) functions + +Once you have written a (public API) function in Blosc2, it is important to: +* Import it from the relevant module in the ``__init__.py`` file +* Add it to the list of functions in ``__all__`` in the ``__init__.py`` file +* If it is present in numpy, add it to the relevant dictionary (``local_ufunc_map``, ``ufunc_map`` ``ufunc_map_1param``) in ``ndarray.py`` + +Finally, you also need to deal with it correctly within 
``shape_utils.py``. + +If the function does not change the shape of the output, simply add it to ``elementwise_funcs`` and you're done. + +If the function _does_ change the shape of the output, it is likely either a reduction, a constructor, or a linear algebra function and so should be added to one of those lists (``reducers``, ``constructors`` or ``linalg_funcs``). If the function is a reduction, unless you need to handle an argument that is neither ``axis`` nor ``keepdims``, you don't need to do anything else. +If your function is a constructor, you need to ensure it is handled within the ``visit_Call`` function appropriately (if it has a shape argument this is easy, just add it to the list of functions that has ``zeros``, ``zeros_like`` etc.). + +For linear algebra functions it is likely you will have to write a bespoke shape handler within the ``linalg_shape`` function. There is also a list ``linalg_attrs`` for attributes which change the shape (currently only ``T`` and ``mT``) should you need to add one. You will probably need to edit the ``validation_patterns`` list at the top of the ``lazyexpr.py`` file to handle these attributes. Just extend the part that has the negative lookahead "(?!real|imag|T|mT|(". + +After this, the imports at the top of ``lazyexpr.py`` should handle things, where an ``eager_funcs`` list is defined to handle eager execution of functions which change the output shape. Finally, in order to handle name changes between NumPy versions 1 and 2, it may be necessary to add aliases for functions within the blocks defined by ``if NUMPY_GE_2_0:`` in ``lazyexpr.py`` and ``ndarray.py``.
diff --git a/src/blosc2/lazyexpr.py b/src/blosc2/lazyexpr.py index cf413f42..b83998f7 100644 --- a/src/blosc2/lazyexpr.py +++ b/src/blosc2/lazyexpr.py @@ -51,7 +51,7 @@ ufunc_map_1param, ) -from .shape_utils import constructors, elementwise_funcs, infer_shape, lin_alg_attrs, lin_alg_funcs, reducers +from .shape_utils import constructors, elementwise_funcs, infer_shape, linalg_attrs, linalg_funcs, reducers if not blosc2.IS_WASM: import numexpr @@ -157,9 +157,9 @@ def ne_evaluate(expression, local_dict=None, **kwargs): "S": np.str_, "V": np.bytes_, } -blosc2_funcs = constructors + lin_alg_funcs + elementwise_funcs + reducers +blosc2_funcs = constructors + linalg_funcs + elementwise_funcs + reducers # functions that have to be evaluated before chunkwise lazyexpr machinery -eager_funcs = lin_alg_funcs + reducers + ["slice"] + lin_alg_attrs +eager_funcs = linalg_funcs + reducers + ["slice"] + ["." + attr for attr in linalg_attrs] # Gather all callable functions in numpy numpy_funcs = { name @@ -569,7 +569,7 @@ def compute_smaller_slice(larger_shape, smaller_shape, larger_slice): validation_patterns = [ r"[\;]", # Flow control characters r"(^|[^\w])__[\w]+__($|[^\w])", # Dunder methods - r"\.\b(?!real|imag|(\d*[eE]?[+-]?\d+)|(\d*[eE]?[+-]?\d+j)|\d*j\b|(sum|prod|min|max|std|mean|var|any|all|where)" + r"\.\b(?!real|imag|T|mT|(\d*[eE]?[+-]?\d+)|(\d*[eE]?[+-]?\d+j)|\d*j\b|(sum|prod|min|max|std|mean|var|any|all|where)" r"\s*\([^)]*\)|[a-zA-Z_]\w*\s*\([^)]*\))", # Attribute patterns ] @@ -595,10 +595,8 @@ def compute_smaller_slice(larger_shape, smaller_shape, larger_slice): valid_methods |= {"float32", "float64", "complex64", "complex128"} valid_methods |= {"bool", "str", "bytes"} valid_methods |= { - name - for name in dir(blosc2.NDArray) - if callable(getattr(blosc2.NDArray, name)) and not name.startswith("_") -} + name for name in dir(blosc2.NDArray) if not name.startswith("_") +} # allow attributes and methods def validate_expr(expr: str) -> None: @@ -621,9 +619,7 @@ def 
validate_expr(expr: str) -> None: # Check for forbidden patterns forbiddens = _blacklist_re.search(skip_quotes) if forbiddens is not None: - i = forbiddens.span()[0] - if expr[i : i + 2] != ".T" and expr[i : i + 3] != ".mT": # allow tranpose methods - raise ValueError(f"'{expr}' is not a valid expression.") + raise ValueError(f"'{expr}' is not a valid expression.") # Check for invalid characters not covered by the tokenizer invalid_chars = re.compile(r"[^\w\s+\-*/%()[].,=<>!&|~^]") diff --git a/src/blosc2/shape_utils.py b/src/blosc2/shape_utils.py index 6b738cd7..32c4a75b 100644 --- a/src/blosc2/shape_utils.py +++ b/src/blosc2/shape_utils.py @@ -79,7 +79,7 @@ "zeros_like", ] -lin_alg_funcs = [ +linalg_funcs = [ "concat", "diagonal", "expand_dims", @@ -94,7 +94,7 @@ "vecdot", ] -lin_alg_attrs = ["T", "mT"] +linalg_attrs = ["T", "mT"] reducers = ["sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "count_nonzero"] # All the available constructors and reducers necessary for the (string) expression evaluator @@ -317,7 +317,7 @@ def elementwise(*args): # --- Function registry --- -FUNCTIONS = { # ignore out arg +REDUCTIONS = { # ignore out arg func: lambda x, axis=None, keepdims=False, out=None: reduce_shape(x, axis, keepdims) for func in reducers # any unknown function will default to elementwise @@ -391,7 +391,7 @@ def visit_Call(self, node): # noqa : C901 kwargs[kw.arg] = self._lookup_value(kw.value) # ------- handle linear algebra --------------- - if base_name in lin_alg_funcs: + if base_name in linalg_funcs: return linalg_shape(base_name, args, kwargs) # ------- handle constructors --------------- @@ -484,8 +484,8 @@ def visit_Call(self, node): # noqa : C901 slices = [self._eval_slice(slice_arg)] return slice_shape(obj_shape, slices) - if base_name in FUNCTIONS: - return FUNCTIONS[base_name](*args, **kwargs) + if base_name in REDUCTIONS: + return REDUCTIONS[base_name](*args, **kwargs) shapes = [s for s in args if s is not None] if base_name not in 
elementwise_funcs: