@@ -630,69 +630,86 @@ def visit_Call(self, node):
630630 return set (visitor .operands )
631631
632632
633- #
634- # def cons_functions(
635- # expression: str,
636- # operands_old: dict[str : blosc2.ndarray | blosc2.LazyExpr],
637- # operands_new: dict[str : blosc2.ndarray | blosc2.LazyExpr],
638- # ) -> tuple(str, dict[str : blosc2.ndarray.NDArray | blosc2.LazyExpr]):
639- # operand_to_key = {id(v): k for k, v in operands_new.items()}
640- # for k, v in operands_old.items(): # extend operands_to_key with old operands
641- # try:
642- # operand_to_key[id(v)]
643- # except KeyError:
644- # operand_to_key[id(v)] = k
645- # operands_new[k] = v
646- #
647- # class OperandVisitor(ast.NodeVisitor):
648- # def __init__(self):
649- # self.operandmap = {}
650- # self.operands = {}
651- # self.opcounter = 0
652- # self.function_names = set()
653- #
654- # def update_func(self, localop):
655- # k = operand_to_key[id(localop)]
656- # if k not in self.operandmap:
657- # newkey = f"o{self.opcounter}"
658- # self.operands[newkey] = operands_new[k]
659- # self.operandmap[k] = newkey
660- # self.opcounter += 1
661- # return newkey
662- # else:
663- # return self.operandmap[k]
664- #
665- # def visit_Name(self, node):
666- # if node.id == "np": # Skip NumPy namespace (e.g. np.int8, which will be treated separately)
667- # return
668- # if node.id in self.function_names: # Skip function names
669- # return
670- # elif node.id not in dtype_symbols:
671- # localop = operands_old[node.id]
672- # if isinstance(localop, blosc2.LazyExpr):
673- # for (
674- # _,
675- # v,
676- # ) in localop.operands.items(): # expression operands already in terms of basic operands
677- # _ = self.update_func(v)
678- # node.id = localop.expression
679- # else:
680- # node.id = self.update_func(localop)
681- # else:
682- # pass
683- # self.generic_visit(node)
684- #
685- # def visit_Call(self, node):
686- # if isinstance(
687- # node.func, ast.Name
688- # ): # visits Call first, then Name, so don't increment operandcounter yet
689- # self.function_names.add(node.func.id)
690- # self.generic_visit(node)
691- #
692- # tree = ast.parse(expression)
693- # visitor = OperandVisitor()
694- # visitor.visit(tree)
695- # return ast.unparse(tree), visitor.operands
633+ def conserve_functions (expression , operands_old , operands_new ): # noqa: C901
634+ """
635+ Given an expression in string form, return its operands.
636+
637+ Parameters
638+ ----------
639+ expression : str
640+ The expression in string form.
641+
642+ operands_old: dict[str : blosc2.ndarray | blosc2.LazyExpr]
643+ Dict of operands from expression prior to eval.
644+
645+ operands_new: dict[str : blosc2.ndarray | blosc2.LazyExpr]
646+ Dict of operands from expression after eval.
647+ Returns
648+ -------
649+ newexpression
650+ A modified string expression with the functions/constructors conserved and
651+ true operands rebased and written in o- notation.
652+ newoperands
653+ Dict of the set of rebased operands.
654+ """
655+
656+ operand_to_key = {id (v ): k for k , v in operands_new .items ()}
657+ # for k, v in operands_old.items(): # extend operands_to_key with old operands
658+ # try:
659+ # operand_to_key[id(v)]
660+ # except KeyError:
661+ # operand_to_key[id(v)] = k
662+ # operands_new[k] = v
663+
664+ class OperandVisitor (ast .NodeVisitor ):
665+ def __init__ (self ):
666+ self .operandmap = {}
667+ self .operands = {}
668+ self .opcounter = 0
669+ self .function_names = set ()
670+
671+ def update_func (self , localop ):
672+ k = operand_to_key [id (localop )]
673+ if k not in self .operandmap :
674+ newkey = f"o{ self .opcounter } "
675+ self .operands [newkey ] = operands_new [k ]
676+ self .operandmap [k ] = newkey
677+ self .opcounter += 1
678+ return newkey
679+ else :
680+ return self .operandmap [k ]
681+
682+ def visit_Name (self , node ):
683+ if node .id == "np" : # Skip NumPy namespace (e.g. np.int8, which will be treated separately)
684+ return
685+ if node .id in self .function_names : # Skip function names
686+ return
687+ elif node .id not in dtype_symbols :
688+ localop = operands_old [node .id ]
689+ if isinstance (localop , blosc2 .LazyExpr ):
690+ for (
691+ _ ,
692+ v ,
693+ ) in localop .operands .items (): # expression operands already in terms of basic operands
694+ _ = self .update_func (v )
695+ node .id = localop .expression
696+ else :
697+ node .id = self .update_func (localop )
698+ else :
699+ pass
700+ self .generic_visit (node )
701+
702+ def visit_Call (self , node ):
703+ if isinstance (
704+ node .func , ast .Name
705+ ): # visits Call first, then Name, so don't increment operandcounter yet
706+ self .function_names .add (node .func .id )
707+ self .generic_visit (node )
708+
709+ tree = ast .parse (expression )
710+ visitor = OperandVisitor ()
711+ visitor .visit (tree )
712+ return ast .unparse (tree ), visitor .operands
696713
697714
698715class TransformNumpyCalls (ast .NodeTransformer ):
@@ -2745,12 +2762,12 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
27452762 if isinstance (new_expr , blosc2 .LazyExpr ):
27462763 # DO NOT restore the original expression and operands
27472764 # Instead rebase operands and restore only constructors
2748- # expression_, operands_ = cons_functions (
2749- # _expression, _operands, new_expr.operands | local_vars
2750- # )
2751- # new_expr.expression = f"({expression_})" # force parenthesis
2765+ expression_ , operands_ = conserve_functions (
2766+ _expression , _operands , new_expr .operands | local_vars
2767+ )
2768+ new_expr .expression = f"({ expression_ } )" # force parenthesis
27522769 new_expr .expression_tosave = expression
2753- # new_expr.operands = operands_
2770+ new_expr .operands = operands_
27542771 new_expr .operands_tosave = operands
27552772 else :
27562773 # An immediate evaluation happened (e.g. all operands are numpy arrays)
0 commit comments