diff --git a/codeflash/verification/instrument_codeflash_capture.py b/codeflash/verification/instrument_codeflash_capture.py index d1f9816dc..adcb66ef8 100644 --- a/codeflash/verification/instrument_codeflash_capture.py +++ b/codeflash/verification/instrument_codeflash_capture.py @@ -92,6 +92,27 @@ def __init__( self.tests_root = tests_root self.inserted_decorator = False + # Precompute decorator components to avoid reconstructing on every node visit + # Only the `function_name` field changes per class + self._base_decorator_keywords = [ + ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)), + ast.keyword(arg="tests_root", value=ast.Constant(value=self.tests_root.as_posix())), + ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)), + ] + self._base_decorator_func = ast.Name(id="codeflash_capture", ctx=ast.Load()) + + # Preconstruct starred/kwargs for super init injection for perf + self._super_starred = ast.Starred(value=ast.Name(id="args", ctx=ast.Load())) + self._super_kwarg = ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load())) + self._super_func = ast.Attribute( + value=ast.Call(func=ast.Name(id="super", ctx=ast.Load()), args=[], keywords=[]), + attr="__init__", + ctx=ast.Load(), + ) + self._init_vararg = ast.arg(arg="args") + self._init_kwarg = ast.arg(arg="kwargs") + self._init_self_arg = ast.arg(arg="self", annotation=None) + def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom: # Check if our import already exists if node.module == "codeflash.verification.codeflash_capture" and any( @@ -114,21 +135,18 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: if node.name not in self.target_classes: return node - # Look for __init__ method has_init = False - - # Create the decorator + # Build decorator node ONCE for each class, not per loop iteration decorator = ast.Call( - func=ast.Name(id="codeflash_capture", ctx=ast.Load()), + func=self._base_decorator_func, args=[], keywords=[ ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")), - ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)), - ast.keyword(arg="tests_root", value=ast.Constant(value=self.tests_root.as_posix())), - ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)), + *self._base_decorator_keywords, ], ) + # Only scan node.body once for both __init__ and decorator check for item in node.body: if ( isinstance(item, ast.FunctionDef) @@ -139,35 +157,28 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: ): has_init = True - # Add decorator at the start of the list if not already present - if not any( - isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture" - for d in item.decorator_list - ): + # Check for existing decorator in-place, stop after finding one + for d in item.decorator_list: + if isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture": + break + else: + # No decorator found item.decorator_list.insert(0, decorator) self.inserted_decorator = True if not has_init: - # Create super().__init__(*args, **kwargs) call + # Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments) super_call = ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Call(func=ast.Name(id="super", ctx=ast.Load()), args=[], keywords=[]), - attr="__init__", - ctx=ast.Load(), - ), - args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()))], - keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))], - ) + value=ast.Call(func=self._super_func, args=[self._super_starred], keywords=[self._super_kwarg]) ) - # Create function arguments: self, *args, **kwargs + # Create function arguments: self, *args, **kwargs (reuse arg nodes) arguments = ast.arguments( posonlyargs=[], - args=[ast.arg(arg="self", annotation=None)], - vararg=ast.arg(arg="args"), + args=[self._init_self_arg], + vararg=self._init_vararg, kwonlyargs=[], kw_defaults=[], - kwarg=ast.arg(arg="kwargs"), + kwarg=self._init_kwarg, defaults=[], )