Skip to content

Commit 38c6bd3

Browse files
committed
Add Francesc's suggestions, improve tests
1 parent 6a25957 commit 38c6bd3

File tree

3 files changed

+83
-27
lines changed

3 files changed

+83
-27
lines changed

src/blosc2/lazyexpr.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353

5454
from .shape_utils import constructors, infer_shape, lin_alg_funcs, reducers
5555

56-
lin_alg_funcs += ("clip", "logaddexp")
56+
not_numexpr_funcs = lin_alg_funcs + ("clip", "logaddexp")
5757

5858
if not blosc2.IS_WASM:
5959
import numexpr
@@ -645,8 +645,11 @@ def validate_expr(expr: str) -> None:
645645
skip_quotes = re.sub(r"(\'[^\']*\')", "", no_whitespace)
646646

647647
# Check for forbidden patterns
648-
if _blacklist_re.search(skip_quotes) is not None:
649-
raise ValueError(f"'{expr}' is not a valid expression.")
648+
forbiddens = _blacklist_re.search(skip_quotes)
649+
if forbiddens is not None:
650+
i = forbiddens.span()[0]
651+
if expr[i : i + 2] != ".T" and expr[i : i + 3] != ".mT": # allow transpose methods
652+
raise ValueError(f"'{expr}' is not a valid expression.")
650653

651654
# Check for invalid characters not covered by the tokenizer
652655
invalid_chars = re.compile(r"[^\w\s+\-*/%()[].,=<>!&|~^]")
@@ -706,7 +709,7 @@ def visit_Call(self, node):
706709
try:
707710
shape = infer_shape(full_expr, shapes)
708711
except Exception as e:
709-
print(f"⚠️ Shape inference failed for {full_expr}: {e}")
712+
print(f"Shape inference failed for {full_expr}: {e}")
710713
shape = ()
711714

712715
# Determine dtype
@@ -2920,7 +2923,7 @@ def find_args(expr):
29202923
return value, expression[idx:idx2]
29212924

29222925
def _compute_expr(self, item, kwargs):
2923-
if any(method in self.expression for method in reducers + lin_alg_funcs):
2926+
if any(method in self.expression for method in reducers + not_numexpr_funcs):
29242927
# We have reductions in the expression (probably coming from a string lazyexpr)
29252928
# Also includes slice
29262929
_globals = get_expr_globals(self.expression)
@@ -3441,8 +3444,8 @@ def _numpy_eval_expr(expression, operands, prefer_blosc=False):
34413444
if "contains" in expression:
34423445
_out = ne_evaluate(expression, local_dict=ops)
34433446
else:
3444-
# Create a globals dict with the functions of blosc2 preferentially
3445-
# (and numpy if can't find blosc2)
3447+
# Create a globals dict with blosc2 version of functions preferentially
3448+
# (default to numpy func if not implemented in blosc2)
34463449
if prefer_blosc:
34473450
_globals = get_expr_globals(expression)
34483451
_globals |= dtype_symbols

src/blosc2/shape_utils.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ def linalg_shape(func_name, args, kwargs): # noqa: C901
104104
batch = broadcast_shapes(a[:-2], b[:-2])
105105
shape = batch
106106
if not x1_is_vector:
107-
shape += a[-2]
107+
shape += (a[-2],)
108108
if not x2_is_vector:
109-
shape += b[-1]
109+
shape += (b[-1],)
110110
return shape
111111

112112
# --- matrix_transpose ---
@@ -154,7 +154,7 @@ def linalg_shape(func_name, args, kwargs): # noqa: C901
154154
elif func_name == "tensordot":
155155
if axes is None and len(args) > 2:
156156
axes = args[2]
157-
if axis is None:
157+
if axes is None:
158158
axes = 2
159159
if b is None:
160160
return None
@@ -168,7 +168,7 @@ def linalg_shape(func_name, args, kwargs): # noqa: C901
168168
return a_rest + b_rest
169169

170170
# --- transpose ---
171-
elif func_name == ("transpose", "T", "mT"):
171+
elif func_name in ("transpose", "T", "mT"):
172172
return a[:-2] + (a[-1], a[-2])
173173

174174
# --- vecdot ---
@@ -261,9 +261,22 @@ def visit_Name(self, node):
261261
else: # passed a scalar value
262262
return ()
263263

264+
def visit_Attribute(self, node):
    """Infer the shape that results from an attribute access such as ``a.T``.

    The shape of the object the attribute is read from is obtained by
    visiting ``node.value``; the result then depends on the attribute name:

    * ``reshape``: if the node carries an ``args`` sequence whose last
      element is a tuple literal, the target shape is built by resolving
      each tuple element with ``self._lookup_value``; otherwise ``()``.
    * ``T`` / ``mT``: delegated to ``linalg_shape`` with the object's shape.
    * anything else: ``None``.
    """
    obj_shape = self.visit(node.value)
    attr = node.attr
    if attr == "reshape":
        # Plain ``ast.Attribute`` nodes have no ``args`` field (call arguments
        # live on the enclosing ``ast.Call``), so read it defensively instead
        # of raising AttributeError on a bare ``x.reshape`` access.
        args = getattr(node, "args", None)
        if args:
            shape_arg = args[-1]
            if isinstance(shape_arg, ast.Tuple):
                return tuple(self._lookup_value(e) for e in shape_arg.elts)
        return ()
    if attr in ("T", "mT"):
        return linalg_shape(attr, (obj_shape,), {})
    return None
276+
264277
def visit_Call(self, node): # noqa : C901
265278
func_name = getattr(node.func, "id", None)
266-
attr_name = getattr(node.func, "attr", None)
279+
attr_name = getattr(node.func, "attr", None) # handle methods called on funcs
267280

268281
# --- Recursive method-chain support ---
269282
obj_shape = None

tests/ndarray/test_lazyexpr.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1629,62 +1629,102 @@ def test_lazylinalg():
16291629
npx = x[()]
16301630
npy = y[()]
16311631
npA = A[()]
1632+
npB = B[()]
1633+
npC = C[()]
1634+
npD = D[()]
16321635

16331636
# --- concat ---
16341637
out = blosc2.lazyexpr("concat((x, y), axis=0)")
1635-
assert out.shape == np.concat((npx, npy), axis=0).shape
1638+
npres = np.concatenate((npx, npy), axis=0)
1639+
assert out.shape == npres.shape
1640+
np.testing.assert_array_almost_equal(out[()], npres)
16361641

16371642
# --- diagonal ---
16381643
out = blosc2.lazyexpr("diagonal(A)")
1639-
assert out.shape == np.diagonal(npA).shape
1644+
npres = np.diagonal(npA)
1645+
assert out.shape == npres.shape
1646+
np.testing.assert_array_almost_equal(out[()], npres)
16401647

16411648
# --- expand_dims ---
16421649
out = blosc2.lazyexpr("expand_dims(x, axis=0)")
1643-
assert out.shape == (1,) + shapes["x"]
1650+
npres = np.expand_dims(npx, axis=0)
1651+
assert out.shape == npres.shape
1652+
np.testing.assert_array_almost_equal(out[()], npres)
16441653

16451654
# --- matmul ---
16461655
out = blosc2.lazyexpr("matmul(A, B)")
1647-
assert out.shape == (shapes["A"][0], shapes["B"][1])
1656+
npres = np.matmul(npA, npB)
1657+
assert out.shape == npres.shape
1658+
np.testing.assert_array_almost_equal(out[()], npres)
16481659

16491660
# --- matrix_transpose ---
16501661
out = blosc2.lazyexpr("matrix_transpose(A)")
1651-
assert out.shape == (shapes["A"][1], shapes["A"][0])
1662+
npres = np.matrix_transpose(npA)
1663+
assert out.shape == npres.shape
1664+
np.testing.assert_array_almost_equal(out[()], npres)
1665+
out = blosc2.lazyexpr("C.mT")
1666+
npres = C.mT
1667+
assert out.shape == npres.shape
1668+
np.testing.assert_array_almost_equal(out[()], npres)
1669+
out = blosc2.lazyexpr("A.T")
1670+
npres = npA.T
1671+
assert out.shape == npres.shape
1672+
np.testing.assert_array_almost_equal(out[()], npres)
16521673

16531674
# --- outer ---
16541675
out = blosc2.lazyexpr("outer(x, y)")
1655-
assert out.shape == shapes["x"] + shapes["y"]
1676+
npres = np.outer(npx, npy)
1677+
assert out.shape == npres.shape
1678+
np.testing.assert_array_almost_equal(out[()], npres)
16561679

16571680
# --- permute_dims ---
16581681
out = blosc2.lazyexpr("permute_dims(C, axes=(2,0,1))")
1659-
assert out.shape == (shapes["C"][2], shapes["C"][0], shapes["C"][1])
1682+
npres = np.transpose(npC, axes=(2, 0, 1))
1683+
assert out.shape == npres.shape
1684+
np.testing.assert_array_almost_equal(out[()], npres)
16601685

16611686
# --- squeeze ---
16621687
out = blosc2.lazyexpr("squeeze(D)")
1663-
assert out.shape == (5,)
1688+
npres = np.squeeze(npD)
1689+
assert out.shape == npres.shape
1690+
np.testing.assert_array_almost_equal(out[()], npres)
1691+
16641692
out = blosc2.lazyexpr("D.squeeze()")
1665-
assert out.shape == (5,)
1693+
npres = np.squeeze(npD)
1694+
assert out.shape == npres.shape
1695+
np.testing.assert_array_almost_equal(out[()], npres)
16661696

16671697
# --- stack ---
16681698
out = blosc2.lazyexpr("stack((x, y), axis=0)")
1669-
assert out.shape == (2,) + shapes["x"]
1699+
npres = np.stack((npx, npy), axis=0)
1700+
assert out.shape == npres.shape
1701+
np.testing.assert_array_almost_equal(out[()], npres)
16701702

16711703
# --- tensordot ---
16721704
out = blosc2.lazyexpr("tensordot(A, B, axes=1)")
1673-
assert out.shape[0] == shapes["A"][0]
1674-
assert out.shape[-1] == shapes["B"][-1]
1705+
npres = np.tensordot(npA, npB, axes=1)
1706+
assert out.shape == npres.shape
1707+
np.testing.assert_array_almost_equal(out[()], npres)
16751708

16761709
# --- vecdot ---
16771710
out = blosc2.lazyexpr("vecdot(x, y)")
1678-
assert out.shape == np.vecdot(x[()], y[()]).shape
1711+
npres = np.vecdot(npx, npy)
1712+
assert out.shape == npres.shape
1713+
np.testing.assert_array_almost_equal(out[()], npres)
16791714

1680-
# batched matmul
1715+
# --- batched matmul ---
16811716
shapes = {
16821717
"A": (1, 3, 4),
16831718
"B": (3, 4, 5),
16841719
}
16851720
s = shapes["A"]
16861721
A = blosc2.linspace(0, np.prod(s), shape=s)
1722+
npA = A[()] # actual numpy array
16871723
s = shapes["B"]
16881724
B = blosc2.linspace(0, np.prod(s), shape=s)
1725+
npB = B[()] # actual numpy array
1726+
16891727
out = blosc2.lazyexpr("matmul(A, B)")
1690-
assert out.shape == (3, 3, 5)
1728+
npres = np.matmul(npA, npB)
1729+
assert out.shape == npres.shape
1730+
np.testing.assert_array_almost_equal(out[()], npres)

0 commit comments

Comments
 (0)