Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 37 additions & 26 deletions codeflash/verification/instrument_codeflash_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,27 @@ def __init__(
self.tests_root = tests_root
self.inserted_decorator = False

# Precompute decorator components to avoid reconstructing on every node visit
# Only the `function_name` field changes per class
self._base_decorator_keywords = [
ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)),
ast.keyword(arg="tests_root", value=ast.Constant(value=self.tests_root.as_posix())),
ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)),
]
self._base_decorator_func = ast.Name(id="codeflash_capture", ctx=ast.Load())

# Preconstruct starred/kwargs for super init injection for perf
self._super_starred = ast.Starred(value=ast.Name(id="args", ctx=ast.Load()))
self._super_kwarg = ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))
self._super_func = ast.Attribute(
value=ast.Call(func=ast.Name(id="super", ctx=ast.Load()), args=[], keywords=[]),
attr="__init__",
ctx=ast.Load(),
)
self._init_vararg = ast.arg(arg="args")
self._init_kwarg = ast.arg(arg="kwargs")
self._init_self_arg = ast.arg(arg="self", annotation=None)

def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
# Check if our import already exists
if node.module == "codeflash.verification.codeflash_capture" and any(
Expand All @@ -114,21 +135,18 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
if node.name not in self.target_classes:
return node

# Look for __init__ method
has_init = False

# Create the decorator
# Build decorator node ONCE for each class, not per loop iteration
decorator = ast.Call(
func=ast.Name(id="codeflash_capture", ctx=ast.Load()),
func=self._base_decorator_func,
args=[],
keywords=[
ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")),
ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)),
ast.keyword(arg="tests_root", value=ast.Constant(value=self.tests_root.as_posix())),
ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)),
*self._base_decorator_keywords,
],
)

# Only scan node.body once for both __init__ and decorator check
for item in node.body:
if (
isinstance(item, ast.FunctionDef)
Expand All @@ -139,35 +157,28 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
):
has_init = True

# Add decorator at the start of the list if not already present
if not any(
isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture"
for d in item.decorator_list
):
# Check for existing decorator in-place, stop after finding one
for d in item.decorator_list:
if isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture":
break
else:
# No decorator found
item.decorator_list.insert(0, decorator)
self.inserted_decorator = True

if not has_init:
# Create super().__init__(*args, **kwargs) call
# Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments)
super_call = ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Call(func=ast.Name(id="super", ctx=ast.Load()), args=[], keywords=[]),
attr="__init__",
ctx=ast.Load(),
),
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()))],
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
)
value=ast.Call(func=self._super_func, args=[self._super_starred], keywords=[self._super_kwarg])
)
# Create function arguments: self, *args, **kwargs
# Create function arguments: self, *args, **kwargs (reuse arg nodes)
arguments = ast.arguments(
posonlyargs=[],
args=[ast.arg(arg="self", annotation=None)],
vararg=ast.arg(arg="args"),
args=[self._init_self_arg],
vararg=self._init_vararg,
kwonlyargs=[],
kw_defaults=[],
kwarg=ast.arg(arg="kwargs"),
kwarg=self._init_kwarg,
defaults=[],
)

Expand Down
Loading