Skip to content

Commit 6a25957

Browse files
committed
Add shape parsing of linalg funcs
1 parent 6e5d0cd commit 6a25957

File tree

4 files changed

+273
-32
lines changed

4 files changed

+273
-32
lines changed

src/blosc2/lazyexpr.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,11 @@ def compute_smaller_slice(larger_shape, smaller_shape, larger_slice):
620620
valid_methods |= {"int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"}
621621
valid_methods |= {"float32", "float64", "complex64", "complex128"}
622622
valid_methods |= {"bool", "str", "bytes"}
623+
valid_methods |= {
624+
name
625+
for name in dir(blosc2.NDArray)
626+
if callable(getattr(blosc2.NDArray, name)) and not name.startswith("_")
627+
}
623628

624629

625630
def validate_expr(expr: str) -> None:
@@ -2002,7 +2007,7 @@ def reduce_slices( # noqa: C901
20022007
continue
20032008

20042009
if where is None:
2005-
if expression == "o0":
2010+
if expression == "o0" or expression == "(o0)":
20062011
# We don't have an actual expression, so avoid a copy except to make contiguous
20072012
result = np.require(chunk_operands["o0"], requirements="C")
20082013
else:
@@ -3168,9 +3173,6 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
31683173
# in guessing mode to avoid computing reductions
31693174
# Extract possible numpy scalars
31703175
_expression, local_vars = extract_numpy_scalars(expression)
3171-
# Let's include numpy and blosc2 as operands so that some functions can be used
3172-
# Most in particular, castings like np.int8 et al. can be very useful to allow
3173-
# for desired data types in the output.
31743176
_operands = operands | local_vars
31753177
# Check that operands are proper Operands, LazyArray or scalars; if not, convert to NDArray objects
31763178
for op, val in _operands.items():
@@ -3179,10 +3181,10 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
31793181
# for scalars just return value (internally converts to () if necessary)
31803182
opshapes = {k: v if not hasattr(v, "shape") else v.shape for k, v in _operands.items()}
31813183
_shape = infer_shape(_expression, opshapes) # infer shape, includes constructors
3182-
# substitutes with numpy operands (cheap for reductions) and
3183-
# defaults to blosc2 functions (cheap for constructors)
31843184
# have to handle slices since a[10] on a dummy variable of shape (1,1) doesn't work
31853185
desliced_expr, desliced_ops = extract_and_replace_slices(_expression, _operands)
3186+
# substitutes with dummy operands (cheap for reductions) and
3187+
# defaults to blosc2 functions (cheap for constructors)
31863188
new_expr = _numpy_eval_expr(desliced_expr, desliced_ops, prefer_blosc=True)
31873189
_dtype = new_expr.dtype
31883190
if isinstance(new_expr, blosc2.LazyExpr):
@@ -3205,24 +3207,16 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
32053207
if counter == 0 and char == ",":
32063208
break
32073209
expression_ = finalexpr[:-1] # remove trailing comma
3208-
new_expr.expression = f"({expression_})" # force parenthesis
3209-
new_expr.expression_tosave = expression
3210-
new_expr.operands = operands_
3211-
new_expr.operands_tosave = operands
3212-
elif isinstance(new_expr, blosc2.NDArray) and len(operands) == 1:
3213-
# passed "a", "a[:10]", 'sum(a)'
3214-
expression_, operands_ = conserve_functions(
3215-
_expression, _operands, {"o0": list(operands.values())[0]} | local_vars
3216-
)
3217-
new_expr = cls(None)
3218-
new_expr.expression = expression_
3219-
new_expr.operands = operands_
32203210
else:
3211+
new_expr = cls(None)
32213212
# An immediate evaluation happened
32223213
# (e.g. all operands are numpy arrays or constructors)
3223-
new_expr = cls(None)
3224-
new_expr.expression = expression
3225-
new_expr.operands = operands
3214+
# or passed "a", "a[:10]", 'sum(a)'
3215+
expression_, operands_ = conserve_functions(_expression, _operands, local_vars)
3216+
new_expr.expression = f"({expression_})" # force parenthesis
3217+
new_expr.operands = operands_
3218+
new_expr.expression_tosave = expression
3219+
new_expr.operands_tosave = operands
32263220
# Cache the dtype and shape (should be immutable)
32273221
new_expr._dtype = _dtype
32283222
new_expr._shape = _shape

src/blosc2/linalg.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -353,10 +353,6 @@ def vecdot(x1: blosc2.NDArray, x2: blosc2.NDArray, axis: int = -1, **kwargs) ->
353353
a_keep[a_axes] = False
354354
b_keep = [True] * x2.ndim
355355
b_keep[b_axes] = False
356-
x1shape = np.array(x1.shape)
357-
x2shape = np.array(x2.shape)
358-
result_shape = np.broadcast_shapes(x1shape[a_keep], x2shape[b_keep])
359-
result = blosc2.zeros(result_shape, dtype=np.result_type(x1, x2), **kwargs)
360356

361357
x1shape = np.array(x1.shape)
362358
x2shape = np.array(x2.shape)

src/blosc2/shape_utils.py

Lines changed: 164 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import builtins
23

34
from numpy import broadcast_shapes
45

@@ -15,6 +16,8 @@
1516
"tensordot",
1617
"transpose",
1718
"vecdot",
19+
"T",
20+
"mT",
1821
)
1922
reducers = ("sum", "prod", "min", "max", "std", "mean", "var", "any", "all", "slice", "count_nonzero")
2023

@@ -39,6 +42,152 @@
3942

4043

4144
# --- Shape utilities ---
45+
def _normalize_axes(axes, ndim):
    """Return ``axes`` (an int or an iterable of ints) as a tuple of non-negative indices."""
    if isinstance(axes, int):
        axes = (axes,)
    return tuple(ax + ndim if ax < 0 else ax for ax in axes)


def linalg_shape(func_name, args, kwargs):  # noqa: C901
    """Infer the result shape of a linear algebra / tensor manipulation function.

    Parameters
    ----------
    func_name: str
        Name of the function (or attribute) being called, e.g. ``"matmul"``.
    args: sequence
        Positional arguments as resolved by the caller. The first one is the
        shape of the first operand (a tuple of ints); for ``concat`` and
        ``stack`` it is a tuple of shapes. The second one, when present, is
        either the shape of the second operand or a positional ``axis``/``axes``.
    kwargs: dict
        Keyword arguments of the call. Only ``axis``, ``axes`` and ``offset``
        are read.

    Returns
    -------
    tuple or None
        The inferred shape, or ``None`` when it cannot be determined (unknown
        function, missing operand, or a first argument with ``None`` entries).
    """
    a = args[0] if args else None
    if a is None or any(s is None for s in a):
        return None
    b = args[1] if len(args) > 1 else None
    axis = kwargs.get("axis")
    axes = kwargs.get("axes")
    offset = kwargs.get("offset", 0)

    # --- concat ---
    if func_name == "concat":
        shapes = a
        if axis is None and len(args) > 1:
            axis = args[1]
        axis = 0 if axis is None else axis
        # normalize negative axis
        axis = axis + len(shapes[0]) if axis < 0 else axis
        concat_dim = builtins.sum(s[axis] for s in shapes)
        return tuple(d if i != axis else concat_dim for i, d in enumerate(shapes[0]))

    # --- diagonal (taken over the last two dimensions) ---
    if func_name == "diagonal":
        if len(a) < 2:
            return None
        rows, cols = a[-2], a[-1]
        # A positive offset moves the diagonal to the right (fewer columns
        # available); a negative one moves it down (fewer rows available).
        if offset >= 0:
            diag_len = builtins.min(rows, cols - offset)
        else:
            diag_len = builtins.min(rows + offset, cols)
        return tuple(a[:-2]) + (builtins.max(0, diag_len),)

    # --- expand_dims ---
    if func_name == "expand_dims":
        # positional axis may be second positional argument
        if axis is None and len(args) > 1:
            axis = args[1]
        if axis is None:
            axis = 0
        axis = (axis,) if isinstance(axis, int) else tuple(axis)
        # Axes refer to positions in the *output*, so negative values must be
        # normalized against the final number of dimensions.
        out_ndim = len(a) + len(axis)
        positions = set(_normalize_axes(axis, out_ndim))
        if len(positions) != len(axis):
            return None  # repeated axes: shape is undefined
        dims = iter(a)
        return tuple(1 if i in positions else next(dims) for i in range(out_ndim))

    # --- matmul ---
    if func_name == "matmul":
        if b is None or len(a) == 0 or len(b) == 0:
            return None
        x1_is_vector = len(a) == 1
        x2_is_vector = len(b) == 1
        if x1_is_vector:
            a = (1,) + a  # (N,) -> (1, N)
        if x2_is_vector:
            b = b + (1,)  # (M,) -> (M, 1)
        shape = tuple(broadcast_shapes(a[:-2], b[:-2]))
        # Dimensions that were added to promote a vector are dropped again
        if not x1_is_vector:
            shape += (a[-2],)
        if not x2_is_vector:
            shape += (b[-1],)
        return shape

    # --- matrix_transpose / transpose / T / mT: swap the last two dimensions ---
    if func_name in ("matrix_transpose", "transpose", "T", "mT"):
        if len(a) < 2:
            return a
        return a[:-2] + (a[-1], a[-2])

    # --- outer ---
    if func_name == "outer":
        if b is None:
            return None
        return a + b

    # --- permute_dims ---
    if func_name == "permute_dims":
        if axes is None and len(args) > 1:
            axes = args[1]
        if axes is None:
            axes = tuple(reversed(range(len(a))))
        return tuple(a[i] for i in axes)

    # --- squeeze ---
    if func_name == "squeeze":
        if axis is None and len(args) > 1:
            axis = args[1]
        if axis is None:
            return tuple(d for d in a if d != 1)
        axis = _normalize_axes(axis, len(a))
        return tuple(d for i, d in enumerate(a) if i not in axis or d != 1)

    # --- stack ---
    if func_name == "stack":
        elems = a
        if axis is None and len(args) > 1:
            axis = args[1]
        if axis is None:
            axis = 0
        first = elems[0]
        # The new axis is a position in the output, which has one more dimension
        axis = axis + len(first) + 1 if axis < 0 else axis
        return first[:axis] + (len(elems),) + first[axis:]

    # --- tensordot ---
    if func_name == "tensordot":
        if axes is None and len(args) > 2:
            axes = args[2]
        if axes is None:
            axes = 2
        if b is None:
            return None
        if isinstance(axes, int):
            # Contract the last `axes` dims of a with the first `axes` dims of b.
            # Slice with an explicit stop so that axes == 0 keeps all of a.
            a_rest = a[: len(a) - axes]
            b_rest = b[axes:]
        else:
            a_axes, b_axes = axes
            a_axes = _normalize_axes(a_axes, len(a))
            b_axes = _normalize_axes(b_axes, len(b))
            a_rest = tuple(d for i, d in enumerate(a) if i not in a_axes)
            b_rest = tuple(d for i, d in enumerate(b) if i not in b_axes)
        return a_rest + b_rest

    # --- vecdot ---
    if func_name == "vecdot":
        if axis is None and len(args) > 2:
            axis = args[2]
        if axis is None:
            axis = -1
        if b is None:
            return None
        # NOTE(review): assumes axis is negative (counted from the end) — TODO confirm
        a_axis = axis + len(a)
        b_axis = axis + len(b)
        a_rem = tuple(d for i, d in enumerate(a) if i != a_axis)
        b_rem = tuple(d for i, d in enumerate(b) if i != b_axis)
        return broadcast_shapes(a_rem, b_rem)

    return None
189+
190+
42191
def reduce_shape(shape, axis, keepdims):
43192
"""Reduce shape along given axis or axes (collapse dimensions)."""
44193
if shape is None:
@@ -133,8 +282,18 @@ def visit_Call(self, node): # noqa : C901
133282
else:
134283
kwargs[kw.arg] = self._lookup_value(kw.value)
135284

285+
# ------- handle linear algebra ---------------
286+
target = None
136287
if func_name in lin_alg_funcs:
137-
return None # need to implement shape handling for these funcs
288+
target = func_name
289+
if attr_name in lin_alg_funcs:
290+
target = attr_name
291+
if target is not None:
292+
args = [self.visit(arg) for arg in node.args]
293+
# If it's a method call, prepend the object shape
294+
if obj_shape is not None and attr_name == target:
295+
args.insert(0, obj_shape)
296+
return linalg_shape(target, args, kwargs)
138297

139298
# ------- handle constructors ---------------
140299
if func_name in constructors or attr_name == "reshape":
@@ -241,7 +400,10 @@ def visit_Compare(self, node):
241400
return elementwise(*shapes)
242401

243402
def visit_Constant(self, node):
244-
return ()
403+
return () if not hasattr(node.value, "shape") else node.value.shape
404+
405+
def visit_Tuple(self, node):
    """Visit every element of a tuple node and return the results as a tuple."""
    return tuple(map(self.visit, node.elts))
245407

246408
def visit_BinOp(self, node):
247409
left = self.visit(node.left)

tests/ndarray/test_lazyexpr.py

Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,10 +1129,10 @@ def test_rebasing(array_fixture):
11291129
assert expr.expression == "(o0 + o1 - o2 * o3)"
11301130

11311131
expr = blosc2.lazyexpr("a1")
1132-
assert expr.expression == "o0"
1132+
assert expr.expression == "(o0)"
11331133

11341134
expr = blosc2.lazyexpr("a1[:10]")
1135-
assert expr.expression == "o0.slice((slice(None, 10, None),))"
1135+
assert expr.expression == "(o0.slice((slice(None, 10, None),)))"
11361136

11371137

11381138
# Test get_chunk method
@@ -1595,7 +1595,96 @@ def test_not_numexpr():
15951595
np.testing.assert_array_almost_equal(d_blosc2, np.logaddexp(npa, npb) + npa)
15961596
# TODO: Implement __add__ etc. for LazyUDF so this line works
15971597
# d_blosc2 = blosc2.evaluate(f"logaddexp(a, b) + clip(a, 6, 12)")
1598-
arr = blosc2.lazyexpr("matmul(a,b) + a ")
1598+
arr = blosc2.lazyexpr("matmul(a, b)")
15991599
assert isinstance(arr, blosc2.LazyExpr)
1600-
assert arr.shape is None # can't calculate shape for linalg funcs yet
1601-
np.testing.assert_array_almost_equal(arr[()], np.matmul(npa, npb) + a)
1600+
np.testing.assert_array_almost_equal(arr[()], np.matmul(npa, npb))
1601+
1602+
1603+
def test_lazylinalg():
    """
    Check the shapes inferred for lazy expressions built on linear algebra funcs
    """
    shapes = {
        "A": (3, 4),
        "B": (4, 5),
        "C": (2, 3, 4),
        "D": (1, 5, 1),
        "x": (10,),
        "y": (10,),
    }
    # The operand names below are the ones referenced inside the expression strings
    x = blosc2.linspace(0, np.prod(shapes["x"]), shape=shapes["x"])
    y = blosc2.linspace(0, np.prod(shapes["y"]), shape=shapes["y"])
    A = blosc2.linspace(0, np.prod(shapes["A"]), shape=shapes["A"])
    B = blosc2.linspace(0, np.prod(shapes["B"]), shape=shapes["B"])
    C = blosc2.linspace(0, np.prod(shapes["C"]), shape=shapes["C"])
    D = blosc2.linspace(0, np.prod(shapes["D"]), shape=shapes["D"])

    npx = x[()]
    npy = y[()]
    npA = A[()]

    # (expression, expected shape) pairs, checked in order
    cases = [
        ("concat((x, y), axis=0)", np.concat((npx, npy), axis=0).shape),
        ("diagonal(A)", np.diagonal(npA).shape),
        ("expand_dims(x, axis=0)", (1,) + shapes["x"]),
        ("matmul(A, B)", (shapes["A"][0], shapes["B"][1])),
        ("matrix_transpose(A)", (shapes["A"][1], shapes["A"][0])),
        ("outer(x, y)", shapes["x"] + shapes["y"]),
        ("permute_dims(C, axes=(2,0,1))", (shapes["C"][2], shapes["C"][0], shapes["C"][1])),
        ("squeeze(D)", (5,)),
        ("D.squeeze()", (5,)),
        ("stack((x, y), axis=0)", (2,) + shapes["x"]),
    ]
    for expression, expected in cases:
        out = blosc2.lazyexpr(expression)
        assert out.shape == expected

    # --- tensordot --- (only the outermost dimensions are checked)
    out = blosc2.lazyexpr("tensordot(A, B, axes=1)")
    assert out.shape[0] == shapes["A"][0]
    assert out.shape[-1] == shapes["B"][-1]

    # --- vecdot ---
    out = blosc2.lazyexpr("vecdot(x, y)")
    assert out.shape == np.vecdot(npx, npy).shape

    # --- batched matmul: the batch dimensions are broadcast ---
    shapes = {
        "A": (1, 3, 4),
        "B": (3, 4, 5),
    }
    A = blosc2.linspace(0, np.prod(shapes["A"]), shape=shapes["A"])
    B = blosc2.linspace(0, np.prod(shapes["B"]), shape=shapes["B"])
    out = blosc2.lazyexpr("matmul(A, B)")
    assert out.shape == (3, 3, 5)

0 commit comments

Comments
 (0)