Skip to content

Commit 383d254

Browse files
Merge pull request #407 from Blosc/fixChaining
Fix chaining
2 parents c4883cb + 2553708 commit 383d254

File tree

2 files changed

+158
-46
lines changed

2 files changed

+158
-46
lines changed

src/blosc2/lazyexpr.py

Lines changed: 109 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,107 @@ def visit_Call(self, node):
630630
return set(visitor.operands)
631631

632632

633+
def conserve_functions( # noqa: C901
634+
expression: str,
635+
operands_old: dict[str, blosc2.NDArray | blosc2.LazyExpr],
636+
operands_new: dict[str, blosc2.NDArray | blosc2.LazyExpr],
637+
) -> tuple(str, dict[str, blosc2.NDArray]):
638+
"""
639+
Given an expression in string form, return its operands.
640+
641+
Parameters
642+
----------
643+
expression : str
644+
The expression in string form.
645+
646+
operands_old: dict[str : blosc2.ndarray | blosc2.LazyExpr]
647+
Dict of operands from expression prior to eval.
648+
649+
operands_new: dict[str : blosc2.ndarray | blosc2.LazyExpr]
650+
Dict of operands from expression after eval.
651+
Returns
652+
-------
653+
newexpression
654+
A modified string expression with the functions/constructors conserved and
655+
true operands rebased and written in o- notation.
656+
newoperands
657+
Dict of the set of rebased operands.
658+
"""
659+
660+
operand_to_key = {id(v): k for k, v in operands_new.items()}
661+
for k, v in operands_old.items(): # extend operands_to_key with old operands
662+
if isinstance(
663+
v, blosc2.LazyExpr
664+
): # unroll operands in LazyExpr (only necessary when have reduced a lazyexpr)
665+
d = v.operands
666+
else:
667+
d = {k: v}
668+
for newk, newv in d.items():
669+
try:
670+
operand_to_key[id(newv)]
671+
except KeyError:
672+
newk = (
673+
f"o{len(operands_new)}" if newk in operands_new else newk
674+
) # possible that names coincide
675+
operand_to_key[id(newv)] = newk
676+
operands_new[newk] = newv
677+
678+
class OperandVisitor(ast.NodeVisitor):
679+
def __init__(self):
680+
self.operandmap = {}
681+
self.operands = {}
682+
self.opcounter = 0
683+
self.function_names = set()
684+
685+
def update_func(self, localop):
686+
k = operand_to_key[id(localop)]
687+
if k not in self.operandmap:
688+
newkey = f"o{self.opcounter}"
689+
self.operands[newkey] = operands_new[k]
690+
self.operandmap[k] = newkey
691+
self.opcounter += 1
692+
return newkey
693+
else:
694+
return self.operandmap[k]
695+
696+
def visit_Name(self, node):
697+
if node.id == "np": # Skip NumPy namespace (e.g. np.int8, which will be treated separately)
698+
return
699+
if node.id in self.function_names: # Skip function names
700+
return
701+
elif node.id not in dtype_symbols:
702+
localop = operands_old[node.id]
703+
if isinstance(localop, blosc2.LazyExpr):
704+
newexpr = localop.expression
705+
for (
706+
opname,
707+
v,
708+
) in localop.operands.items(): # expression operands already in terms of basic operands
709+
newopname = self.update_func(v)
710+
newexpr = re.sub(
711+
rf"(?<=\s){opname}|(?<=\(){opname}", newopname, newexpr
712+
) # replace with newopname
713+
node.id = newexpr
714+
else:
715+
node.id = self.update_func(localop)
716+
else:
717+
pass
718+
self.generic_visit(node)
719+
720+
def visit_Call(self, node):
721+
if isinstance(
722+
node.func, ast.Name
723+
): # visits Call first, then Name, so don't increment operandcounter yet
724+
self.function_names.add(node.func.id)
725+
self.generic_visit(node)
726+
727+
tree = ast.parse(expression)
728+
visitor = OperandVisitor()
729+
visitor.visit(tree)
730+
newexpression, newoperands = ast.unparse(tree), visitor.operands
731+
return newexpression, newoperands
732+
733+
633734
class TransformNumpyCalls(ast.NodeTransformer):
634735
def __init__(self):
635736
self.replacements = {}
@@ -2678,10 +2779,14 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
26782779
_dtype = new_expr.dtype
26792780
_shape = new_expr.shape
26802781
if isinstance(new_expr, blosc2.LazyExpr):
2681-
# Restore the original expression and operands
2682-
new_expr.expression = f"({_expression})" # forcibly add parenthesis
2683-
new_expr.expression_tosave = _expression
2684-
new_expr.operands = _operands
2782+
# DO NOT restore the original expression and operands
2783+
# Instead rebase operands and restore only constructors
2784+
expression_, operands_ = conserve_functions(
2785+
_expression, _operands, new_expr.operands | local_vars
2786+
)
2787+
new_expr.expression = f"({expression_})" # force parenthesis
2788+
new_expr.expression_tosave = expression
2789+
new_expr.operands = operands_
26852790
new_expr.operands_tosave = operands
26862791
else:
26872792
# An immediate evaluation happened (e.g. all operands are numpy arrays)

tests/ndarray/test_lazyexpr.py

Lines changed: 49 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,45 +1354,52 @@ def test_chain_expressions():
13541354
le3_ = blosc2.lazyexpr("(le2 & (b < 0))", {"le2": le2_, "b": b})
13551355
assert (le3_[:] == le3[:]).all()
13561356

1357-
# TODO: This test should pass eventually
1358-
# le1 = a ** 3 + blosc2.sin(a ** 2)
1359-
# le2 = le1 < c
1360-
# le3 = (b < 0)
1361-
# le4 = le2 & le3
1362-
# le1_ = blosc2.lazyexpr("a ** 3 + sin(a ** 2)", {"a": a})
1363-
# le2_ = blosc2.lazyexpr("(le1 < c)", {"le1": le1_, "c": c})
1364-
# le3_ = blosc2.lazyexpr("(b < 0)", {"b": b})
1365-
# le4_ = blosc2.lazyexpr("(le2 & le3)", {"le2": le2_, "le3": le3_})
1366-
# assert (le4_[:] == le4[:]).all()
1367-
1368-
1369-
# TODO: Test the chaining of multiple persistent lazy expressions
1370-
# def test_chain_persistentexpressions():
1371-
# N = 1_000
1372-
# dtype = "float64"
1373-
# a = blosc2.linspace(0, 1, N * N, dtype=dtype, shape=(N, N), urlpath="a.b2nd", mode="w")
1374-
# b = blosc2.linspace(1, 2, N * N, dtype=dtype, shape=(N, N), urlpath="b.b2nd", mode="w")
1375-
# c = blosc2.linspace(0, 1, N, dtype=dtype, shape=(N,), urlpath="c.b2nd", mode="w")
1376-
#
1377-
# le1 = a ** 3 + blosc2.sin(a ** 2)
1378-
# le2 = le1 < c
1379-
# le3 = (b < 0)
1380-
# le4 = le2 & le3
1381-
#
1382-
# le1_ = blosc2.lazyexpr("a ** 3 + sin(a ** 2)", {"a": a})
1383-
# le1_.save("expr1.b2nd", mode="w")
1384-
# myle1 = blosc2.open("expr1.b2nd")
1385-
#
1386-
# le2_ = blosc2.lazyexpr("(le1 < c)", {"le1": myle1, "c": c})
1387-
# le2_.save("expr2.b2nd", mode="w")
1388-
# myle2 = blosc2.open("expr2.b2nd")
1389-
#
1390-
# le3_ = blosc2.lazyexpr("(b < 0)", {"b": b})
1391-
# le3_.save("expr3.b2nd", mode="w")
1392-
# myle3 = blosc2.open("expr3.b2nd")
1393-
#
1394-
# le4_ = blosc2.lazyexpr("(le2 & le3)", {"le2": myle2, "le3": myle3})
1395-
# le4_.save("expr4.b2nd", mode="w")
1396-
# myle4 = blosc2.open("expr4.b2nd")
1397-
# print((myle4[:] == le4[:]).all())
1398-
#
1357+
le1 = a**3 + blosc2.sin(a**2)
1358+
le2 = le1 < c
1359+
le3 = b < 0
1360+
le4 = le2 & le3
1361+
le1_ = blosc2.lazyexpr("a ** 3 + sin(a ** 2)", {"a": a})
1362+
le2_ = blosc2.lazyexpr("(le1 < c)", {"le1": le1_, "c": c})
1363+
le3_ = blosc2.lazyexpr("(b < 0)", {"b": b})
1364+
le4_ = blosc2.lazyexpr("(le2 & le3)", {"le2": le2_, "le3": le3_})
1365+
assert (le4_[:] == le4[:]).all()
1366+
1367+
# TODO: Eventually this test should pass
1368+
# expr1 = blosc2.lazyexpr("arange(N) + b")
1369+
# expr2 = blosc2.lazyexpr("a * b + 1")
1370+
# expr = blosc2.lazyexpr("expr1 - expr2")
1371+
# expr_final = blosc2.lazyexpr("expr * expr")
1372+
# nres = (expr * expr)[:]
1373+
# res = expr_final.compute()
1374+
# np.testing.assert_allclose(res[:], nres)
1375+
1376+
1377+
# Test the chaining of multiple persistent lazy expressions
1378+
def test_chain_persistentexpressions():
1379+
N = 1_000
1380+
dtype = "float64"
1381+
a = blosc2.linspace(0, 1, N * N, dtype=dtype, shape=(N, N), urlpath="a.b2nd", mode="w")
1382+
b = blosc2.linspace(1, 2, N * N, dtype=dtype, shape=(N, N), urlpath="b.b2nd", mode="w")
1383+
c = blosc2.linspace(0, 1, N, dtype=dtype, shape=(N,), urlpath="c.b2nd", mode="w")
1384+
1385+
le1 = a**3 + blosc2.sin(a**2)
1386+
le2 = le1 < c
1387+
le3 = b < 0
1388+
le4 = le2 & le3
1389+
1390+
le1_ = blosc2.lazyexpr("a ** 3 + sin(a ** 2)", {"a": a})
1391+
le1_.save("expr1.b2nd", mode="w")
1392+
myle1 = blosc2.open("expr1.b2nd")
1393+
1394+
le2_ = blosc2.lazyexpr("(le1 < c)", {"le1": myle1, "c": c})
1395+
le2_.save("expr2.b2nd", mode="w")
1396+
myle2 = blosc2.open("expr2.b2nd")
1397+
1398+
le3_ = blosc2.lazyexpr("(b < 0)", {"b": b})
1399+
le3_.save("expr3.b2nd", mode="w")
1400+
myle3 = blosc2.open("expr3.b2nd")
1401+
1402+
le4_ = blosc2.lazyexpr("(le2 & le3)", {"le2": myle2, "le3": myle3})
1403+
le4_.save("expr4.b2nd", mode="w")
1404+
myle4 = blosc2.open("expr4.b2nd")
1405+
assert (myle4[:] == le4[:]).all()

0 commit comments

Comments
 (0)