Skip to content

Commit bcd1b09

Browse files
author
Luke Shaw
committed
Finalise handling of reductions
1 parent 084fa90 commit bcd1b09

File tree

2 files changed

+31
-11
lines changed

2 files changed

+31
-11
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ ignore = [
123123
"RUF015",
124124
"SIM108",
125125
"UP038", # https://github.com/astral-sh/ruff/issues/7871
126+
"C402", #allow dict comprehension as {}
126127
]
127128

128129
[tool.ruff.lint.extend-per-file-ignores]

src/blosc2/lazyexpr.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,11 @@ def visit_Call(self, node):
630630
return set(visitor.operands)
631631

632632

633-
def conserve_functions(expression, operands_old, operands_new): # noqa: C901
633+
def conserve_functions( # noqa: C901
634+
expression: str,
635+
operands_old: dict[str, blosc2.NDArray | blosc2.LazyExpr],
636+
operands_new: dict[str, blosc2.NDArray | blosc2.LazyExpr],
637+
) -> tuple(str, dict[str, blosc2.NDArray]):
634638
"""
635639
Given an expression in string form, return its operands.
636640
@@ -654,12 +658,22 @@ def conserve_functions(expression, operands_old, operands_new): # noqa: C901
654658
"""
655659

656660
operand_to_key = {id(v): k for k, v in operands_new.items()}
657-
# for k, v in operands_old.items(): # extend operands_to_key with old operands
658-
# try:
659-
# operand_to_key[id(v)]
660-
# except KeyError:
661-
# operand_to_key[id(v)] = k
662-
# operands_new[k] = v
661+
for k, v in operands_old.items(): # extend operands_to_key with old operands
662+
if isinstance(
663+
v, blosc2.LazyExpr
664+
): # unroll operands in LazyExpr (only necessary when have reduced a lazyexpr)
665+
d = v.operands
666+
else:
667+
d = {k: v}
668+
for newk, newv in d.items():
669+
try:
670+
operand_to_key[id(newv)]
671+
except KeyError:
672+
newk = (
673+
f"o{len(operands_new)}" if newk in operands_new else newk
674+
) # possible that names coincide
675+
operand_to_key[id(newv)] = newk
676+
operands_new[newk] = newv
663677

664678
class OperandVisitor(ast.NodeVisitor):
665679
def __init__(self):
@@ -687,12 +701,16 @@ def visit_Name(self, node):
687701
elif node.id not in dtype_symbols:
688702
localop = operands_old[node.id]
689703
if isinstance(localop, blosc2.LazyExpr):
704+
newexpr = localop.expression
690705
for (
691-
_,
706+
opname,
692707
v,
693708
) in localop.operands.items(): # expression operands already in terms of basic operands
694-
_ = self.update_func(v)
695-
node.id = localop.expression
709+
newopname = self.update_func(v)
710+
newexpr = re.sub(
711+
rf"(?<=\s){opname}|(?<=\(){opname}", newopname, newexpr
712+
) # replace with newopname
713+
node.id = newexpr
696714
else:
697715
node.id = self.update_func(localop)
698716
else:
@@ -709,7 +727,8 @@ def visit_Call(self, node):
709727
tree = ast.parse(expression)
710728
visitor = OperandVisitor()
711729
visitor.visit(tree)
712-
return ast.unparse(tree), visitor.operands
730+
newexpression, newoperands = ast.unparse(tree), visitor.operands
731+
return newexpression, newoperands
713732

714733

715734
class TransformNumpyCalls(ast.NodeTransformer):

0 commit comments

Comments
 (0)