Skip to content

Commit cfdbd4d

Browse files
author
Luke Shaw
committed
Almost fixed operand handling
1 parent 258b1bc commit cfdbd4d

File tree

2 files changed

+76
-48
lines changed

2 files changed

+76
-48
lines changed

src/blosc2/lazyexpr.py

Lines changed: 68 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -630,49 +630,69 @@ def visit_Call(self, node):
630630
return set(visitor.operands)
631631

632632

633-
def conserve_functions(
634-
expression: str,
635-
operands_old: dict[str : blosc2.ndarray | blosc2.LazyExpr],
636-
operands_new: dict[str : blosc2.ndarray | blosc2.LazyExpr],
637-
) -> tuple(str, dict[str : blosc2.ndarray.NDArray | blosc2.LazyExpr]):
638-
operand_to_key = {id(v): k for k, v in operands_new.items()}
639-
640-
class OperandVisitor(ast.NodeVisitor):
641-
def __init__(self):
642-
self.operandset = set()
643-
self.operands = {}
644-
self.opcounter = 0
645-
self.function_names = set()
646-
647-
def visit_Name(self, node):
648-
if node.id == "np":
649-
# Skip NumPy namespace (e.g. np.int8, which will be treated separately)
650-
return
651-
if node.id in self.function_names:
652-
# Skip function names
653-
return
654-
elif (node.id not in dtype_symbols) and (node.id not in self.operandset):
655-
k = operand_to_key[id(operands_old[node.id])]
656-
newkey = f"o{self.opcounter}"
657-
self.operands[newkey] = operands_new[k]
658-
node.id = newkey
659-
self.operandset.add(node.id)
660-
self.opcounter += 1
661-
else:
662-
pass
663-
self.generic_visit(node)
664-
665-
def visit_Call(self, node):
666-
if isinstance(
667-
node.func, ast.Name
668-
): # visits Call first, then Name, so don't increment operandcounter yet
669-
self.function_names.add(node.func.id)
670-
self.generic_visit(node)
671-
672-
tree = ast.parse(expression)
673-
visitor = OperandVisitor()
674-
visitor.visit(tree)
675-
return ast.unparse(tree), visitor.operands
633+
#
634+
# def cons_functions(
635+
# expression: str,
636+
# operands_old: dict[str : blosc2.ndarray | blosc2.LazyExpr],
637+
# operands_new: dict[str : blosc2.ndarray | blosc2.LazyExpr],
638+
# ) -> tuple(str, dict[str : blosc2.ndarray.NDArray | blosc2.LazyExpr]):
639+
# operand_to_key = {id(v): k for k, v in operands_new.items()}
640+
# for k, v in operands_old.items(): # extend operand_to_key with old operands
641+
# try:
642+
# operand_to_key[id(v)]
643+
# except KeyError:
644+
# operand_to_key[id(v)] = k
645+
# operands_new[k] = v
646+
#
647+
# class OperandVisitor(ast.NodeVisitor):
648+
# def __init__(self):
649+
# self.operandmap = {}
650+
# self.operands = {}
651+
# self.opcounter = 0
652+
# self.function_names = set()
653+
#
654+
# def update_func(self, localop):
655+
# k = operand_to_key[id(localop)]
656+
# if k not in self.operandmap:
657+
# newkey = f"o{self.opcounter}"
658+
# self.operands[newkey] = operands_new[k]
659+
# self.operandmap[k] = newkey
660+
# self.opcounter += 1
661+
# return newkey
662+
# else:
663+
# return self.operandmap[k]
664+
#
665+
# def visit_Name(self, node):
666+
# if node.id == "np": # Skip NumPy namespace (e.g. np.int8, which will be treated separately)
667+
# return
668+
# if node.id in self.function_names: # Skip function names
669+
# return
670+
# elif node.id not in dtype_symbols:
671+
# localop = operands_old[node.id]
672+
# if isinstance(localop, blosc2.LazyExpr):
673+
# for (
674+
# _,
675+
# v,
676+
# ) in localop.operands.items(): # expression operands already in terms of basic operands
677+
# _ = self.update_func(v)
678+
# node.id = localop.expression
679+
# else:
680+
# node.id = self.update_func(localop)
681+
# else:
682+
# pass
683+
# self.generic_visit(node)
684+
#
685+
# def visit_Call(self, node):
686+
# if isinstance(
687+
# node.func, ast.Name
688+
# ): # visits Call first, then Name, so don't increment opcounter yet
689+
# self.function_names.add(node.func.id)
690+
# self.generic_visit(node)
691+
#
692+
# tree = ast.parse(expression)
693+
# visitor = OperandVisitor()
694+
# visitor.visit(tree)
695+
# return ast.unparse(tree), visitor.operands
676696

677697

678698
class TransformNumpyCalls(ast.NodeTransformer):
@@ -2725,12 +2745,12 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
27252745
if isinstance(new_expr, blosc2.LazyExpr):
27262746
# DO NOT restore the original expression and operands
27272747
# Instead rebase operands and restore only constructors
2728-
expression_, operands_ = conserve_functions(
2729-
_expression, _operands, new_expr.operands | local_vars
2730-
)
2731-
new_expr.expression = f"({expression_})" # force parenthesis
2748+
# expression_, operands_ = cons_functions(
2749+
# _expression, _operands, new_expr.operands | local_vars
2750+
# )
2751+
# new_expr.expression = f"({expression_})" # force parenthesis
27322752
new_expr.expression_tosave = expression
2733-
new_expr.operands = operands_
2753+
# new_expr.operands = operands_
27342754
new_expr.operands_tosave = operands
27352755
else:
27362756
# An immediate evaluation happened (e.g. all operands are numpy arrays)

tests/ndarray/test_lazyexpr.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,14 @@ def test_chain_expressions():
13641364
le4_ = blosc2.lazyexpr("(le2 & le3)", {"le2": le2_, "le3": le3_})
13651365
assert (le4_[:] == le4[:]).all()
13661366

1367+
expr1 = blosc2.lazyexpr("arange(N) + b")
1368+
expr2 = blosc2.lazyexpr("a * b + 1")
1369+
expr = blosc2.lazyexpr("expr1 - expr2")
1370+
expr_final = blosc2.lazyexpr("expr * expr")
1371+
nres = (expr * expr)[:]
1372+
res = expr_final.compute()
1373+
np.testing.assert_allclose(res[:], nres)
1374+
13671375

13681376
# Test the chaining of multiple persistent lazy expressions
13691377
def test_chain_persistentexpressions():

0 commit comments

Comments
 (0)