Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions codeflash/languages/python/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,47 @@ def _get_dataclass_config(class_node: ast.ClassDef, import_aliases: dict[str, st
return False, False, False


# Module names a decorator's second-to-last dotted component must match to count as attrs.
_ATTRS_NAMESPACES = frozenset(("attrs", "attr"))
# Final dotted component of the decorators that are treated as declaring an attrs class.
_ATTRS_DECORATOR_NAMES = frozenset(("define", "mutable", "frozen", "s", "attrs"))


def _resolve_decorator_name(expr_name: str, import_aliases: dict[str, str]) -> str:
    """Translate a decorator's dotted name through the module's import aliases.

    Args:
        expr_name: Dotted name of the decorator expression as written (e.g. ``"a.define"``).
        import_aliases: Mapping from a name to the name it resolves to
            (presumably built from the module's import statements; verify against caller).

    Returns:
        The alias target when ``expr_name`` itself is a key; otherwise ``expr_name`` with
        its first dotted component replaced by that component's alias target; otherwise
        ``expr_name`` unchanged.
    """
    # A whole-name alias wins over a root-only alias.
    whole_match = import_aliases.get(expr_name)
    if whole_match is not None:
        return whole_match

    root, dot, remainder = expr_name.partition(".")
    if not dot:
        # Undotted name with no alias: nothing to rewrite.
        return expr_name

    root_target = import_aliases.get(root)
    if root_target is None:
        return expr_name
    return f"{root_target}.{remainder}"


def _get_attrs_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]:
    """Detect an attrs class decorator and read its ``init`` / ``kw_only`` settings.

    Args:
        class_node: The class definition whose decorators are inspected.
        import_aliases: Alias mapping used to resolve each decorator's dotted name.

    Returns:
        ``(is_attrs, init_enabled, kw_only)`` for the first matching decorator, where
        ``init_enabled`` defaults to True and ``kw_only`` to False unless overridden by a
        boolean keyword on the decorator call. ``(False, False, False)`` when no decorator
        matches.
    """
    for decorator in class_node.decorator_list:
        dotted = _get_expr_name(decorator)
        if dotted is None:
            continue

        segments = _resolve_decorator_name(dotted, import_aliases).split(".")
        is_attrs_decorator = (
            len(segments) >= 2
            and segments[-2] in _ATTRS_NAMESPACES
            and segments[-1] in _ATTRS_DECORATOR_NAMES
        )
        if not is_attrs_decorator:
            continue

        # Defaults that apply when the decorator is bare or omits the keyword.
        flags = {"init": True, "kw_only": False}
        if isinstance(decorator, ast.Call):
            for keyword in decorator.keywords:
                value = _bool_literal(keyword.value)
                # Keywords for which _bool_literal yields None are ignored.
                if value is not None and keyword.arg in flags:
                    flags[keyword.arg] = value
        return True, flags["init"], flags["kw_only"]

    return False, False, False


def _is_classvar_annotation(annotation: ast.expr, import_aliases: dict[str, str]) -> bool:
    """Report whether an annotation expression names ``ClassVar``.

    For a subscripted annotation the subscripted value is checked; otherwise the
    annotation itself is checked. The name comparison is delegated to
    ``_expr_matches_name`` together with ``import_aliases``.
    """
    if isinstance(annotation, ast.Subscript):
        target = annotation.value
    else:
        target = annotation
    return _expr_matches_name(target, import_aliases, "ClassVar")
Expand All @@ -885,10 +926,13 @@ def _class_has_explicit_init(class_node: ast.ClassDef) -> bool:

def _collect_synthetic_constructor_type_names(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> set[str]:
is_dataclass, dataclass_init_enabled, _ = _get_dataclass_config(class_node, import_aliases)
if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass:
is_attrs, attrs_init_enabled, _ = _get_attrs_config(class_node, import_aliases)
if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass and not is_attrs:
return set()
if is_dataclass and not dataclass_init_enabled:
return set()
if is_attrs and not attrs_init_enabled:
return set()

names = set[str]()
for item in class_node.body:
Expand Down Expand Up @@ -939,9 +983,9 @@ def _extract_synthetic_init_parameters(
kw_only = literal_value
elif keyword.arg == "default":
default_value = _get_node_source(keyword.value, module_source)
elif keyword.arg == "default_factory":
# Default factories still imply an optional constructor parameter, but
# the generated __init__ does not use the field() call directly.
elif keyword.arg in {"default_factory", "factory"}:
# Default factories (dataclass default_factory= / attrs factory=) still imply
# an optional constructor parameter.
default_value = "..."
else:
default_value = _get_node_source(item.value, module_source)
Expand All @@ -960,13 +1004,17 @@ def _build_synthetic_init_stub(
) -> str | None:
is_namedtuple = _is_namedtuple_class(class_node, import_aliases)
is_dataclass, dataclass_init_enabled, dataclass_kw_only = _get_dataclass_config(class_node, import_aliases)
if not is_namedtuple and not is_dataclass:
is_attrs, attrs_init_enabled, attrs_kw_only = _get_attrs_config(class_node, import_aliases)
if not is_namedtuple and not is_dataclass and not is_attrs:
return None
if is_dataclass and not dataclass_init_enabled:
return None
if is_attrs and not attrs_init_enabled:
return None

kw_only_by_default = dataclass_kw_only or attrs_kw_only
parameters = _extract_synthetic_init_parameters(
class_node, module_source, import_aliases, kw_only_by_default=dataclass_kw_only
class_node, module_source, import_aliases, kw_only_by_default=kw_only_by_default
)
if not parameters:
return None
Expand Down
103 changes: 103 additions & 0 deletions codeflash/languages/python/instrument_codeflash_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.formatter import sort_imports
from codeflash.languages.python.context.code_context_extractor import _ATTRS_DECORATOR_NAMES, _ATTRS_NAMESPACES

if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(
self.has_import = False
self.tests_root = tests_root
self.inserted_decorator = False
self._attrs_classes_to_patch: dict[str, ast.Call] = {}

# Precompute decorator components to avoid reconstructing on every node visit
# Only the `function_name` field changes per class
Expand Down Expand Up @@ -118,6 +120,18 @@ def __init__(
defaults=[],
)

# Pre-build reusable AST nodes for _build_attrs_patch_block
self._load_ctx = ast.Load()
self._store_ctx = ast.Store()
self._args_name_load = ast.Name(id="args", ctx=self._load_ctx)
self._kwargs_name_load = ast.Name(id="kwargs", ctx=self._load_ctx)
self._self_arg_node = ast.arg(arg="self")
self._args_arg_node = ast.arg(arg="args")
self._kwargs_arg_node = ast.arg(arg="kwargs")
self._self_name_load = ast.Name(id="self", ctx=self._load_ctx)
self._starred_args = ast.Starred(value=self._args_name_load, ctx=self._load_ctx)
self._kwargs_keyword = ast.keyword(arg=None, value=self._kwargs_name_load)

def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
# Check if our import already exists
if node.module == "codeflash.verification.codeflash_capture" and any(
Expand All @@ -128,6 +142,17 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:

def visit_Module(self, node: ast.Module) -> ast.Module:
self.generic_visit(node)

# Insert module-level monkey-patch wrappers for attrs classes immediately after their
# class definitions. We do this before inserting the import so indices stay stable.
if self._attrs_classes_to_patch:
new_body: list[ast.stmt] = []
for stmt in node.body:
new_body.append(stmt)
if isinstance(stmt, ast.ClassDef) and stmt.name in self._attrs_classes_to_patch:
new_body.extend(self._build_attrs_patch_block(stmt.name, self._attrs_classes_to_patch[stmt.name]))
node.body = new_body

# Add import statement
if not self.has_import and self.inserted_decorator:
import_stmt = ast.parse("from codeflash.verification.codeflash_capture import codeflash_capture").body[0]
Expand Down Expand Up @@ -188,6 +213,33 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
if base_name is not None and base_name.endswith("NamedTuple"):
return node

# Attrs classes — their __init__ is auto-generated by the decorator at class creation
# time. With slots=True (the default for @attrs.define), attrs creates a brand-new class
# object, so the __class__ cell baked into a synthesised
# `super().__init__(*args, **kwargs)` still refers to the *original* class while `self`
# is already an instance of the *new* slots class, causing a TypeError.
# We therefore skip modifying the class body and instead emit a module-level
# monkey-patch block after the class (handled in visit_Module):
# _codeflash_orig_ClassName_init = ClassName.__init__
# def _codeflash_patched_ClassName_init(self, *a, **kw): ...
# ClassName.__init__ = codeflash_capture(...)(_codeflash_patched_ClassName_init)
for dec in node.decorator_list:
dec_name = self._expr_name(dec)
if dec_name is not None:
parts = dec_name.split(".")
if len(parts) >= 2 and parts[-2] in _ATTRS_NAMESPACES and parts[-1] in _ATTRS_DECORATOR_NAMES:
if isinstance(dec, ast.Call):
for kw in dec.keywords:
if (
kw.arg == "init"
and isinstance(kw.value, ast.Constant)
and kw.value.value is False
):
return node
self._attrs_classes_to_patch[node.name] = decorator
self.inserted_decorator = True
return node

# Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments)
super_call = self._super_call_expr
# Create the complete function using prebuilt arguments/body but attach the class-specific decorator
Expand All @@ -202,6 +254,57 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:

return node

def _build_attrs_patch_block(self, class_name: str, decorator: ast.Call) -> list[ast.stmt]:
    """Build the statements that wrap ``class_name.__init__`` with ``decorator``.

    The returned list holds three statements, in order:

        _codeflash_orig_<Class>_init = <Class>.__init__
        def _codeflash_patched_<Class>_init(self, *args, **kwargs):
            return _codeflash_orig_<Class>_init(self, *args, **kwargs)
        <Class>.__init__ = <decorator>(_codeflash_patched_<Class>_init)

    Args:
        class_name: Name of the class whose ``__init__`` is wrapped.
        decorator: Call node that is invoked with the patched function as its only argument.

    Returns:
        The assignment, the function definition and the re-assignment described above.
        Source locations are not set on the new nodes here.
    """
    load, store = self._load_ctx, self._store_ctx
    saved_init_name = f"_codeflash_orig_{class_name}_init"
    wrapper_name = f"_codeflash_patched_{class_name}_init"

    def class_init_attribute(ctx: ast.expr_context) -> ast.Attribute:
        # <Class>.__init__ in the requested load/store context.
        return ast.Attribute(value=ast.Name(id=class_name, ctx=load), attr="__init__", ctx=ctx)

    # _codeflash_orig_<Class>_init = <Class>.__init__
    remember_original = ast.Assign(
        targets=[ast.Name(id=saved_init_name, ctx=store)],
        value=class_init_attribute(load),
    )

    # return _codeflash_orig_<Class>_init(self, *args, **kwargs)
    forward_call = ast.Call(
        func=ast.Name(id=saved_init_name, ctx=load),
        args=[self._self_name_load, self._starred_args],
        keywords=[self._kwargs_keyword],
    )

    # def _codeflash_patched_<Class>_init(self, *args, **kwargs): ...
    wrapper_signature = ast.arguments(
        posonlyargs=[],
        args=[self._self_arg_node],
        vararg=self._args_arg_node,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=self._kwargs_arg_node,
        defaults=[],
    )
    wrapper_def = ast.FunctionDef(
        name=wrapper_name,
        args=wrapper_signature,
        body=[ast.Return(value=forward_call)],
        decorator_list=[],
        returns=None,
    )

    # <Class>.__init__ = <decorator>(_codeflash_patched_<Class>_init)
    install_wrapper = ast.Assign(
        targets=[class_init_attribute(store)],
        value=ast.Call(func=decorator, args=[ast.Name(id=wrapper_name, ctx=load)], keywords=[]),
    )

    return [remember_original, wrapper_def, install_wrapper]

def _expr_name(self, node: ast.AST) -> str | None:
if isinstance(node, ast.Name):
return node.id
Expand Down
62 changes: 62 additions & 0 deletions tests/test_code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5013,6 +5013,68 @@ def process(cfg: ChildConfig) -> str:
assert "qualified_name: str" in combined


def test_extract_init_stub_attrs_define(tmp_path: Path) -> None:
"""extract_init_stub_from_class produces a synthetic __init__ stub for @attrs.define classes."""
# NOTE(review): tmp_path is not referenced in this test body; the source is parsed in memory.
# The decorator is called with frozen=True only, and every field is an attrs.field(...)
# carrying just converter=/validator= keywords (no default or factory).
source = """
import attrs
from attrs.validators import instance_of

@attrs.define(frozen=True)
class ImportCST:
module: str = attrs.field(converter=str.lower)
name: str = attrs.field(validator=[instance_of(str)])
as_name: str = attrs.field(validator=[instance_of(str)])

def to_str(self) -> str:
return f"from {self.module} import {self.name}"
"""
# All three annotated fields are expected as parameters without defaults, in declaration order.
expected = "class ImportCST:\n def __init__(self, module: str, name: str, as_name: str):\n ..."
tree = ast.parse(source)
stub = extract_init_stub_from_class("ImportCST", source, tree)
assert stub == expected


def test_extract_init_stub_attrs_factory_fields(tmp_path: Path) -> None:
"""Fields using attrs factory= keyword should appear as optional (= ...) in the stub."""
# NOTE(review): tmp_path is not referenced in this test body; the source is parsed in memory.
# The decorator is used bare (@attrs.define, no call), and two of the three fields pass factory=.
source = """
import attrs

@attrs.define
class ClassCST:
name: str = attrs.field()
methods: list = attrs.field(factory=list)
imports: set = attrs.field(factory=set)

def compute(self) -> int:
return len(self.methods)
"""
# `name` (plain attrs.field()) stays required; the factory= fields get the `= ...` placeholder.
expected = "class ClassCST:\n def __init__(self, name: str, methods: list = ..., imports: set = ...):\n ..."
tree = ast.parse(source)
stub = extract_init_stub_from_class("ClassCST", source, tree)
assert stub == expected


def test_extract_init_stub_attrs_init_disabled(tmp_path: Path) -> None:
"""When @attrs.define(init=False) but with explicit __init__, the explicit body is returned."""
# NOTE(review): tmp_path is not referenced in this test body; the source is parsed in memory.
# The class disables the generated initializer (init=False) and defines its own __init__.
source = """
import attrs

@attrs.define(init=False)
class NoAutoInit:
x: int = attrs.field()

def __init__(self, x: int):
self.x = x * 2

def get(self) -> int:
return self.x
"""
# The expected stub carries the hand-written __init__ signature and body, not a `...` placeholder.
expected = "class NoAutoInit:\n def __init__(self, x: int):\n self.x = x * 2"
tree = ast.parse(source)
stub = extract_init_stub_from_class("NoAutoInit", source, tree)
assert stub == expected


def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None:
"""Third-party classes should produce compact __init__ stubs, not full class source."""
# Use a real third-party package (pydantic) so jedi can actually resolve it
Expand Down
Loading
Loading