@@ -630,7 +630,11 @@ def visit_Call(self, node):
630630 return set (visitor .operands )
631631
632632
633- def conserve_functions (expression , operands_old , operands_new ): # noqa: C901
633+ def conserve_functions ( # noqa: C901
634+ expression : str ,
635+ operands_old : dict [str , blosc2 .NDArray | blosc2 .LazyExpr ],
636+ operands_new : dict [str , blosc2 .NDArray | blosc2 .LazyExpr ],
637+ ) -> tuple (str , dict [str , blosc2 .NDArray ]):
634638 """
635639 Given an expression in string form, return its operands.
636640
@@ -654,12 +658,22 @@ def conserve_functions(expression, operands_old, operands_new): # noqa: C901
654658 """
655659
656660 operand_to_key = {id (v ): k for k , v in operands_new .items ()}
657- # for k, v in operands_old.items(): # extend operands_to_key with old operands
658- # try:
659- # operand_to_key[id(v)]
660- # except KeyError:
661- # operand_to_key[id(v)] = k
662- # operands_new[k] = v
661+ for k , v in operands_old .items (): # extend operands_to_key with old operands
662+ if isinstance (
663+ v , blosc2 .LazyExpr
664+ ): # unroll operands in LazyExpr (only necessary when have reduced a lazyexpr)
665+ d = v .operands
666+ else :
667+ d = {k : v }
668+ for newk , newv in d .items ():
669+ try :
670+ operand_to_key [id (newv )]
671+ except KeyError :
672+ newk = (
673+ f"o{ len (operands_new )} " if newk in operands_new else newk
674+ ) # possible that names coincide
675+ operand_to_key [id (newv )] = newk
676+ operands_new [newk ] = newv
663677
664678 class OperandVisitor (ast .NodeVisitor ):
665679 def __init__ (self ):
@@ -687,12 +701,16 @@ def visit_Name(self, node):
687701 elif node .id not in dtype_symbols :
688702 localop = operands_old [node .id ]
689703 if isinstance (localop , blosc2 .LazyExpr ):
704+ newexpr = localop .expression
690705 for (
691- _ ,
706+ opname ,
692707 v ,
693708 ) in localop .operands .items (): # expression operands already in terms of basic operands
694- _ = self .update_func (v )
695- node .id = localop .expression
709+ newopname = self .update_func (v )
710+ newexpr = re .sub (
711+ rf"(?<=\s){ opname } |(?<=\(){ opname } " , newopname , newexpr
712+ ) # replace with newopname
713+ node .id = newexpr
696714 else :
697715 node .id = self .update_func (localop )
698716 else :
@@ -709,7 +727,8 @@ def visit_Call(self, node):
709727 tree = ast .parse (expression )
710728 visitor = OperandVisitor ()
711729 visitor .visit (tree )
712- return ast .unparse (tree ), visitor .operands
730+ newexpression , newoperands = ast .unparse (tree ), visitor .operands
731+ return newexpression , newoperands
713732
714733
715734class TransformNumpyCalls (ast .NodeTransformer ):
0 commit comments