Skip to content

Commit b461889

Browse files
committed
Improve handling of known/unknown funcs
1 parent 580e3f5 commit b461889

File tree

3 files changed

+100
-142
lines changed

3 files changed

+100
-142
lines changed

src/blosc2/lazyexpr.py

Lines changed: 7 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,13 @@
4545
_check_allowed_dtypes,
4646
get_chunks_idx,
4747
get_intersecting_chunks,
48-
is_inside_new_expr,
4948
local_ufunc_map,
5049
process_key,
5150
ufunc_map,
5251
ufunc_map_1param,
5352
)
5453

55-
from .shape_utils import constructors, infer_shape, lin_alg_funcs, reducers
54+
from .shape_utils import constructors, elementwise_funcs, infer_shape, lin_alg_attrs, lin_alg_funcs, reducers
5655

5756
if not blosc2.IS_WASM:
5857
import numexpr
@@ -158,119 +157,9 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
158157
"S": np.str_,
159158
"V": np.bytes_,
160159
}
161-
162-
blosc2_funcs = [
163-
"abs",
164-
"acos",
165-
"acosh",
166-
"add",
167-
"all",
168-
"any",
169-
"arange",
170-
"arccos",
171-
"arccosh",
172-
"arcsin",
173-
"arcsinh",
174-
"arctan",
175-
"arctan2",
176-
"arctanh",
177-
"asin",
178-
"asinh",
179-
"atan",
180-
"atan2",
181-
"atanh",
182-
"bitwise_and",
183-
"bitwise_invert",
184-
"bitwise_left_shift",
185-
"bitwise_or",
186-
"bitwise_right_shift",
187-
"bitwise_xor",
188-
"broadcast_to",
189-
"ceil",
190-
"clip",
191-
"concat",
192-
"concatenate",
193-
"copy",
194-
"copysign",
195-
"count_nonzero",
196-
"divide",
197-
"empty",
198-
"empty_like",
199-
"equal",
200-
"expand_dims",
201-
"expm1",
202-
"eye",
203-
"floor",
204-
"floor_divide",
205-
"frombuffer",
206-
"fromiter",
207-
"full",
208-
"full_like",
209-
"greater",
210-
"greater_equal",
211-
"hypot",
212-
"isfinite",
213-
"isinf",
214-
"isnan",
215-
"less_equal",
216-
"less_than",
217-
"linspace",
218-
"log",
219-
"log1p",
220-
"log2",
221-
"log10",
222-
"logaddexp",
223-
"logical_and",
224-
"logical_not",
225-
"logical_or",
226-
"logical_xor",
227-
"matmul",
228-
"matrix_transpose",
229-
"max",
230-
"maximum",
231-
"mean",
232-
"meshgrid",
233-
"min",
234-
"minimum",
235-
"multiply",
236-
"nans",
237-
"ndarray_from_cframe",
238-
"negative",
239-
"nextafter",
240-
"not_equal",
241-
"ones",
242-
"ones_like",
243-
"permute_dims",
244-
"positive",
245-
"pow",
246-
"prod",
247-
"real",
248-
"reciprocal",
249-
"remainder",
250-
"reshape",
251-
"round",
252-
"sign",
253-
"signbit",
254-
"sort",
255-
"square",
256-
"squeeze",
257-
"stack",
258-
"sum",
259-
"subtract",
260-
"take",
261-
"take_along_axis",
262-
"tan",
263-
"tanh",
264-
"tensordot",
265-
"transpose",
266-
"trunc",
267-
"var",
268-
"vecdot",
269-
"where",
270-
"zeros",
271-
"zeros_like",
272-
]
273-
160+
blosc2_funcs = constructors + lin_alg_funcs + elementwise_funcs + reducers
161+
# Functions that must be evaluated eagerly, before the chunkwise lazyexpr machinery runs
162+
eager_funcs = lin_alg_funcs + reducers + ["slice"] + lin_alg_attrs
274163
# Gather all callable functions in numpy
275164
numpy_funcs = {
276165
name
@@ -751,12 +640,7 @@ def validate_expr(expr: str) -> None:
751640

752641
def extract_and_replace_slices(expr, operands):
753642
"""
754-
Replaces all var.slice(...).slice(...) chains in expr with oN temporary variables.
755-
Infers shapes using infer_shape and creates placeholder arrays in new_ops.
756-
757-
Returns:
758-
new_expr: expression string with oN replacements
759-
new_ops: dictionary mapping variable names (original and oN) to arrays
643+
Return new expression and operands with op.slice(...) replaced by temporary operands.
760644
"""
761645
# Copy shapes and operands
762646
shapes = {k: () if not hasattr(v, "shape") else v.shape for k, v in operands.items()}
@@ -1976,11 +1860,6 @@ def reduce_slices( # noqa: C901
19761860
if out is not None and reduced_shape != out.shape:
19771861
raise ValueError("Provided output shape does not match the reduced shape.")
19781862

1979-
if is_inside_new_expr():
1980-
# We already have the dtype and reduced_shape, so return immediately
1981-
# Use a blosc2 container, as it consumes less memory in general
1982-
return blosc2.zeros(reduced_shape, dtype=dtype)
1983-
19841863
# Choose the array with the largest shape as the reference for chunks
19851864
# Note: we could have expr = blosc2.lazyexpr('numpy_array + 1') (i.e. no choice for chunks)
19861865
blosc2_arrs = tuple(o for o in operands.values() if hasattr(o, "chunks"))
@@ -3023,7 +2902,7 @@ def _compute_expr(self, item, kwargs): # noqa : C901
30232902
}
30242903
)
30252904

3026-
if any(method in self.expression for method in reducers + lin_alg_funcs):
2905+
if any(method in self.expression for method in eager_funcs):
30272906
# We have eagerly evaluated functions (reductions, linear algebra) in the expression (probably coming from a string lazyexpr)
30282907
# This also covers slice(...) calls and the T/mT attributes
30292908
_globals = get_expr_globals(self.expression)
@@ -3073,7 +2952,7 @@ def _compute_expr(self, item, kwargs): # noqa : C901
30732952
# Replace the constructor call by the new operand
30742953
newexpr = newexpr.replace(constexpr, newop)
30752954

3076-
_globals = {func: getattr(blosc2, func) for func in functions if func in newexpr}
2955+
_globals = get_expr_globals(newexpr)
30772956
lazy_expr = eval(newexpr, _globals, newops)
30782957
if isinstance(lazy_expr, blosc2.NDArray):
30792958
# Almost done (probably the expression is made of only constructors)

src/blosc2/shape_utils.py

Lines changed: 91 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,85 @@
11
import ast
22
import builtins
3+
import warnings
34

45
from numpy import broadcast_shapes
56

6-
lin_alg_funcs = (
7+
elementwise_funcs = [
8+
"abs",
9+
"acos",
10+
"acosh",
11+
"add",
12+
"arccos",
13+
"arccosh",
14+
"arcsin",
15+
"arcsinh",
16+
"arctan",
17+
"arctan2",
18+
"arctanh",
19+
"asin",
20+
"asinh",
21+
"atan",
22+
"atan2",
23+
"atanh",
24+
"bitwise_and",
25+
"bitwise_invert",
26+
"bitwise_left_shift",
27+
"bitwise_or",
28+
"bitwise_right_shift",
29+
"bitwise_xor",
30+
"broadcast_to",
31+
"ceil",
32+
"clip",
33+
"copysign",
34+
"divide",
35+
"equal",
36+
"expm1",
37+
"floor",
38+
"floor_divide",
39+
"greater",
40+
"greater_equal",
41+
"hypot",
42+
"isfinite",
43+
"isinf",
44+
"isnan",
45+
"less_equal",
46+
"less_than",
47+
"log",
48+
"log1p",
49+
"log2",
50+
"log10",
51+
"logaddexp",
52+
"logical_and",
53+
"logical_not",
54+
"logical_or",
55+
"logical_xor",
56+
"maximum",
57+
"minimum",
58+
"multiply",
59+
"negative",
60+
"nextafter",
61+
"not_equal",
62+
"positive",
63+
"pow",
64+
"real",
65+
"reciprocal",
66+
"remainder",
67+
"round",
68+
"sign",
69+
"signbit",
70+
"square",
71+
"sum",
72+
"subtract",
73+
"tan",
74+
"tanh",
75+
"trunc",
76+
"var",
77+
"where",
78+
"zeros",
79+
"zeros_like",
80+
]
81+
82+
lin_alg_funcs = [
783
"concat",
884
"diagonal",
985
"expand_dims",
@@ -16,15 +92,13 @@
1692
"tensordot",
1793
"transpose",
1894
"vecdot",
19-
"T",
20-
"mT",
21-
"take",
22-
"take_along_axis",
23-
)
24-
reducers = ("sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "slice", "count_nonzero")
95+
]
96+
97+
lin_alg_attrs = ["T", "mT"]
98+
reducers = ["sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "count_nonzero"]
2599

26100
# All the available constructors and reducers necessary for the (string) expression evaluator
27-
constructors = (
101+
constructors = [
28102
"arange",
29103
"linspace",
30104
"fromiter",
@@ -38,9 +112,10 @@
38112
"ones_like",
39113
"empty_like",
40114
"eye",
41-
)
115+
"nans",
116+
]
42117
# Note that, as reshape is accepted as a method too, it should always come last in the list
43-
constructors += ("reshape",)
118+
constructors += ["reshape"]
44119

45120

46121
# --- Shape utilities ---
@@ -336,6 +411,7 @@ def visit_Call(self, node): # noqa : C901
336411
"zeros_like",
337412
"empty_like",
338413
"ones_like",
414+
"nans",
339415
):
340416
if node.args:
341417
shape_arg = node.args[0]
@@ -412,6 +488,11 @@ def visit_Call(self, node): # noqa : C901
412488
return FUNCTIONS[base_name](*args, **kwargs)
413489

414490
shapes = [s for s in args if s is not None]
491+
if base_name not in elementwise_funcs:
492+
warnings.warn(
493+
f"Function shape parser not implemented for {base_name}.", UserWarning, stacklevel=2
494+
)
495+
# Default to elementwise broadcasting, but warn when the function has no explicitly defined shape rule
415496
return elementwise(*shapes) if shapes else ()
416497

417498
def visit_Compare(self, node):

tests/ndarray/test_reductions.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -459,12 +459,10 @@ def test_reduction_index():
459459
assert arr.shape == newarr.shape
460460

461461
a = blosc2.ones(shape=(0, 0))
462-
arr = blosc2.lazyexpr("sum(a, axis=(0, 1, 2))", {"a": a})
463462
with pytest.raises(np.exceptions.AxisError):
464-
newarr = arr.compute()
465-
arr = blosc2.lazyexpr("sum(a, axis=(0, 0))", {"a": a})
463+
arr = blosc2.lazyexpr("sum(a, axis=(0, 1, 2))", {"a": a})
466464
with pytest.raises(ValueError):
467-
newarr = arr.compute()
465+
arr = blosc2.lazyexpr("sum(a, axis=(0, 0))", {"a": a})
468466

469467

470468
@pytest.mark.parametrize("idx", [0, 1, (0,), slice(1, 2), (slice(0, 1),), slice(0, 4), (0, 2)])

0 commit comments

Comments
 (0)