Skip to content

Commit 71967b6

Browse files
committed
Enable lazy imperative matmul etc.
1 parent 7424ce8 commit 71967b6

File tree

3 files changed

+61
-29
lines changed

3 files changed

+61
-29
lines changed

src/blosc2/lazyexpr.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@
5151
ufunc_map_1param,
5252
)
5353

54-
from .shape_utils import constructors, infer_shape, reducers
54+
from .shape_utils import constructors, infer_shape, lin_alg_funcs, reducers
55+
56+
lin_alg_funcs += ("clip", "logaddexp")
5557

5658
if not blosc2.IS_WASM:
5759
import numexpr
@@ -2913,13 +2915,13 @@ def find_args(expr):
29132915
return value, expression[idx:idx2]
29142916

29152917
def _compute_expr(self, item, kwargs):
2916-
if any(method in self.expression for method in reducers):
2918+
if any(method in self.expression for method in reducers + lin_alg_funcs):
29172919
# We have reductions in the expression (probably coming from a string lazyexpr)
29182920
# Also includes slice
29192921
_globals = get_expr_globals(self.expression)
29202922
lazy_expr = eval(self.expression, _globals, self.operands)
29212923
if not isinstance(lazy_expr, blosc2.LazyExpr):
2922-
key, mask = process_key(item, self.shape)
2924+
key, mask = process_key(item, lazy_expr.shape)
29232925
# An immediate evaluation happened (e.g. all operands are numpy arrays)
29242926
if hasattr(self, "_where_args"):
29252927
# We need to apply the where() operation

src/blosc2/shape_utils.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,21 @@
22

33
from numpy import broadcast_shapes
44

5-
reducers = ("sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "slice")
5+
lin_alg_funcs = (
6+
"concat",
7+
"diagonal",
8+
"expand_dims",
9+
"matmul",
10+
"matrix_transpose",
11+
"outer",
12+
"permute_dims",
13+
"squeeze",
14+
"stack",
15+
"tensordot",
16+
"transpose",
17+
"vecdot",
18+
)
19+
reducers = ("sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "slice", "count_nonzero")
620

721
# All the available constructors and reducers necessary for the (string) expression evaluator
822
constructors = (
@@ -18,6 +32,7 @@
1832
"zeros_like",
1933
"ones_like",
2034
"empty_like",
35+
"eye",
2136
)
2237
# Note that, as reshape is accepted as a method too, it should always come last in the list
2338
constructors += ("reshape",)
@@ -50,6 +65,8 @@ def reduce_shape(shape, axis, keepdims):
5065

5166
def slice_shape(shape, slices):
5267
"""Infer shape after slicing."""
68+
if shape is None:
69+
return None
5370
result = []
5471
for dim, sl in zip(shape, slices, strict=False):
5572
if isinstance(sl, int): # indexing removes the axis
@@ -68,11 +85,9 @@ def slice_shape(shape, slices):
6885

6986
def elementwise(*args):
7087
"""All args must broadcast elementwise."""
71-
shape = args[0]
72-
shape = shape if shape is not None else ()
73-
for s in args[1:]:
74-
shape = broadcast_shapes(shape, s) if s is not None else shape
75-
return shape
88+
if None in args:
89+
return None
90+
return broadcast_shapes(*args)
7691

7792

7893
# --- Function registry ---
@@ -118,6 +133,9 @@ def visit_Call(self, node): # noqa : C901
118133
else:
119134
kwargs[kw.arg] = self._lookup_value(kw.value)
120135

136+
if func_name in lin_alg_funcs:
137+
return None # need to implement shape handling for these funcs
138+
121139
# ------- handle constructors ---------------
122140
if func_name in constructors or attr_name == "reshape":
123141
# shape kwarg directly provided
@@ -139,7 +157,7 @@ def visit_Call(self, node): # noqa : C901
139157
if node.args:
140158
shape_arg = node.args[0]
141159
if isinstance(shape_arg, ast.Tuple):
142-
shape = tuple(self._const_or_lookup(e) for e in shape_arg.elts)
160+
shape = tuple(self._lookup_value(e) for e in shape_arg.elts)
143161
elif isinstance(shape_arg, ast.Constant):
144162
shape = (shape_arg.value,)
145163
else:
@@ -149,10 +167,10 @@ def visit_Call(self, node): # noqa : C901
149167

150168
# ---- arange ----
151169
elif func_name == "arange":
152-
start = self._const_or_lookup(node.args[0]) if node.args else 0
153-
stop = self._const_or_lookup(node.args[1]) if len(node.args) > 1 else None
154-
step = self._const_or_lookup(node.args[2]) if len(node.args) > 2 else 1
155-
shape = self._const_or_lookup(node.args[4]) if len(node.args) > 4 else kwargs.get("shape")
170+
start = self._lookup_value(node.args[0]) if node.args else 0
171+
stop = self._lookup_value(node.args[1]) if len(node.args) > 1 else None
172+
step = self._lookup_value(node.args[2]) if len(node.args) > 2 else 1
173+
shape = self._lookup_value(node.args[4]) if len(node.args) > 4 else kwargs.get("shape")
156174

157175
if shape is not None:
158176
return shape if isinstance(shape, tuple) else (shape,)
@@ -169,8 +187,8 @@ def visit_Call(self, node): # noqa : C901
169187

170188
# ---- linspace ----
171189
elif func_name == "linspace":
172-
num = self._const_or_lookup(node.args[2]) if len(node.args) > 2 else kwargs.get("num")
173-
shape = self._const_or_lookup(node.args[5]) if len(node.args) > 5 else kwargs.get("shape")
190+
num = self._lookup_value(node.args[2]) if len(node.args) > 2 else kwargs.get("num")
191+
shape = self._lookup_value(node.args[5]) if len(node.args) > 5 else kwargs.get("shape")
174192
if shape is not None:
175193
return shape if isinstance(shape, tuple) else (shape,)
176194
if num is not None:
@@ -180,12 +198,16 @@ def visit_Call(self, node): # noqa : C901
180198
elif func_name == "frombuffer" or func_name == "fromiter":
181199
count = kwargs.get("count")
182200
return (count,) if count else ()
201+
elif func_name == "eye":
202+
N = self._lookup_value(node.args[0])
203+
M = self._lookup_value(node.args[1]) if len(node.args) > 1 else kwargs.get("M")
204+
return (N, N) if M is None else (N, M)
183205

184206
elif func_name == "reshape" or attr_name == "reshape":
185207
if node.args:
186208
shape_arg = node.args[-1]
187209
if isinstance(shape_arg, ast.Tuple):
188-
return tuple(self._const_or_lookup(e) for e in shape_arg.elts)
210+
return tuple(self._lookup_value(e) for e in shape_arg.elts)
189211
return ()
190212

191213
else:
@@ -218,12 +240,13 @@ def visit_Compare(self, node):
218240
shapes = [self.visit(node.left)] + [self.visit(c) for c in node.comparators]
219241
return elementwise(*shapes)
220242

243+
def visit_Constant(self, node):
244+
return ()
245+
221246
def visit_BinOp(self, node):
222247
left = self.visit(node.left)
223248
right = self.visit(node.right)
224-
left = () if left is None else left
225-
right = () if right is None else right
226-
return broadcast_shapes(left, right)
249+
return elementwise(left, right)
227250

228251
def _eval_slice(self, node):
229252
if isinstance(node, ast.Slice):
@@ -250,15 +273,6 @@ def _lookup_value(self, node):
250273
else:
251274
return None
252275

253-
def _const_or_lookup(self, node):
254-
"""Return constant value or resolve name to scalar from shapes."""
255-
if isinstance(node, ast.Constant):
256-
return node.value
257-
elif isinstance(node, ast.Name):
258-
return self.shapes.get(node.id, None)
259-
else:
260-
return None
261-
262276

263277
# --- Public API ---
264278
def infer_shape(expr, shapes):

tests/ndarray/test_lazyexpr.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,3 +1583,19 @@ def __len__(self):
15831583
lb = blosc2.lazyexpr("b + c + 1")
15841584

15851585
np.testing.assert_array_equal(lb[:], a + a + 1)
1586+
1587+
1588+
def test_not_numexpr():
1589+
shape = (20, 20)
1590+
a = blosc2.linspace(0, 20, num=np.prod(shape), shape=shape)
1591+
b = blosc2.ones((20, 1))
1592+
d_blosc2 = blosc2.evaluate("logaddexp(a, b) + a")
1593+
npa = a[()]
1594+
npb = b[()]
1595+
np.testing.assert_array_almost_equal(d_blosc2, np.logaddexp(npa, npb) + npa)
1596+
# TODO: Implement __add__ etc. for LazyUDF so this line works
1597+
# d_blosc2 = blosc2.evaluate(f"logaddexp(a, b) + clip(a, 6, 12)")
1598+
arr = blosc2.lazyexpr("matmul(a,b) + a ")
1599+
assert isinstance(arr, blosc2.LazyExpr)
1600+
assert arr.shape is None # can't calculate shape for linalg funcs yet
1601+
np.testing.assert_array_almost_equal(arr[()], np.matmul(npa, npb) + a)

0 commit comments

Comments
 (0)