diff --git a/codeflash/code_utils/git_utils.py b/codeflash/code_utils/git_utils.py index 039c4987a..55b3f99ca 100644 --- a/codeflash/code_utils/git_utils.py +++ b/codeflash/code_utils/git_utils.py @@ -13,7 +13,6 @@ from unidiff import PatchSet from codeflash.cli_cmds.console import logger -from codeflash.languages.registry import get_supported_extensions if TYPE_CHECKING: from git import Repo @@ -39,7 +38,6 @@ def get_git_diff( uni_diff_text = repository.git.diff( commit.hexsha + "^1", commit.hexsha, ignore_blank_lines=True, ignore_space_at_eol=True ) - supported_extensions = set(get_supported_extensions()) patch_set = PatchSet(StringIO(uni_diff_text)) change_list: dict[str, list[int]] = {} # list of changes for patched_file in patch_set: diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index bfbf02fc4..497657109 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -861,6 +861,47 @@ def _get_dataclass_config(class_node: ast.ClassDef, import_aliases: dict[str, st return False, False, False +_ATTRS_NAMESPACES = frozenset({"attrs", "attr"}) +_ATTRS_DECORATOR_NAMES = frozenset({"define", "mutable", "frozen", "s", "attrs"}) + + +def _resolve_decorator_name(expr_name: str, import_aliases: dict[str, str]) -> str: + resolved = import_aliases.get(expr_name) + if resolved is not None: + return resolved + first_part, sep, rest = expr_name.partition(".") + if sep: + root_resolved = import_aliases.get(first_part) + if root_resolved is not None: + return f"{root_resolved}.{rest}" + + return expr_name + + +def _get_attrs_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]: + for decorator in class_node.decorator_list: + expr_name = _get_expr_name(decorator) + if expr_name is None: + continue + resolved = _resolve_decorator_name(expr_name, import_aliases) + parts = resolved.split(".") + if len(parts) < 2 or parts[-2] not in _ATTRS_NAMESPACES or parts[-1] not in _ATTRS_DECORATOR_NAMES: + continue + init_enabled = True + kw_only = False + if isinstance(decorator, ast.Call): + for keyword in decorator.keywords: + literal_value = _bool_literal(keyword.value) + if literal_value is None: + continue + if keyword.arg == "init": + init_enabled = literal_value + elif keyword.arg == "kw_only": + kw_only = literal_value + return True, init_enabled, kw_only + return False, False, False + + def _is_classvar_annotation(annotation: ast.expr, import_aliases: dict[str, str]) -> bool: annotation_root = annotation.value if isinstance(annotation, ast.Subscript) else annotation return _expr_matches_name(annotation_root, import_aliases, "ClassVar") @@ -885,10 +926,13 @@ def _class_has_explicit_init(class_node: ast.ClassDef) -> bool: def _collect_synthetic_constructor_type_names(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> set[str]: is_dataclass, dataclass_init_enabled, _ = _get_dataclass_config(class_node, import_aliases) - if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass: + is_attrs, attrs_init_enabled, _ = _get_attrs_config(class_node, import_aliases) + if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass and not is_attrs: return set() if is_dataclass and not dataclass_init_enabled: return set() + if is_attrs and not attrs_init_enabled: + return set() names = set[str]() for item in class_node.body: @@ -939,9 +983,9 @@ def _extract_synthetic_init_parameters( kw_only = literal_value elif keyword.arg == "default": default_value = _get_node_source(keyword.value, module_source) - elif keyword.arg == "default_factory": - # Default factories still imply an optional constructor parameter, but - # the generated __init__ does not use the field() call directly. + elif keyword.arg in {"default_factory", "factory"}: + # Default factories (dataclass default_factory= / attrs factory=) still imply + # an optional constructor parameter. default_value = "..." else: default_value = _get_node_source(item.value, module_source) @@ -960,13 +1004,17 @@ def _build_synthetic_init_stub( ) -> str | None: is_namedtuple = _is_namedtuple_class(class_node, import_aliases) is_dataclass, dataclass_init_enabled, dataclass_kw_only = _get_dataclass_config(class_node, import_aliases) - if not is_namedtuple and not is_dataclass: + is_attrs, attrs_init_enabled, attrs_kw_only = _get_attrs_config(class_node, import_aliases) + if not is_namedtuple and not is_dataclass and not is_attrs: return None if is_dataclass and not dataclass_init_enabled: return None + if is_attrs and not attrs_init_enabled: + return None + kw_only_by_default = dataclass_kw_only or attrs_kw_only parameters = _extract_synthetic_init_parameters( - class_node, module_source, import_aliases, kw_only_by_default=dataclass_kw_only + class_node, module_source, import_aliases, kw_only_by_default=kw_only_by_default ) if not parameters: return None diff --git a/codeflash/languages/python/instrument_codeflash_capture.py b/codeflash/languages/python/instrument_codeflash_capture.py index 3b77a1f53..cda06b1dc 100644 --- a/codeflash/languages/python/instrument_codeflash_capture.py +++ b/codeflash/languages/python/instrument_codeflash_capture.py @@ -2,10 +2,11 @@ import ast from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.formatter import sort_imports +from codeflash.languages.python.context.code_context_extractor import _ATTRS_DECORATOR_NAMES, _ATTRS_NAMESPACES if TYPE_CHECKING: from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -80,6 +81,7 @@ def __init__( self.has_import = False self.tests_root = tests_root self.inserted_decorator = False + self._attrs_classes_to_patch: dict[str, ast.Call] = {} # Precompute decorator components to avoid reconstructing on every node visit # Only the `function_name` field changes per class @@ -118,6 +120,21 @@ def __init__( defaults=[], ) + # Pre-build reusable AST nodes for _build_attrs_patch_block + self._load_ctx = ast.Load() + self._store_ctx = ast.Store() + self._args_name_load = ast.Name(id="args", ctx=self._load_ctx) + self._kwargs_name_load = ast.Name(id="kwargs", ctx=self._load_ctx) + self._self_arg_node = ast.arg(arg="self") + self._args_arg_node = ast.arg(arg="args") + self._kwargs_arg_node = ast.arg(arg="kwargs") + self._self_name_load = ast.Name(id="self", ctx=self._load_ctx) + self._starred_args = ast.Starred(value=self._args_name_load, ctx=self._load_ctx) + self._kwargs_keyword = ast.keyword(arg=None, value=self._kwargs_name_load) + + # Pre-parse the import statement to avoid repeated parsing in visit_Module + self._import_stmt = ast.parse("from codeflash.verification.codeflash_capture import codeflash_capture").body[0] + def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom: # Check if our import already exists if node.module == "codeflash.verification.codeflash_capture" and any( @@ -128,10 +145,20 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom: def visit_Module(self, node: ast.Module) -> ast.Module: self.generic_visit(node) + + # Insert module-level monkey-patch wrappers for attrs classes immediately after their + # class definitions. We do this before inserting the import so indices stay stable. + if self._attrs_classes_to_patch: + new_body: list[ast.stmt] = [] + for stmt in node.body: + new_body.append(stmt) + if isinstance(stmt, ast.ClassDef) and stmt.name in self._attrs_classes_to_patch: + new_body.extend(self._build_attrs_patch_block(stmt.name, self._attrs_classes_to_patch[stmt.name])) + node.body = new_body + # Add import statement if not self.has_import and self.inserted_decorator: - import_stmt = ast.parse("from codeflash.verification.codeflash_capture import codeflash_capture").body[0] - node.body.insert(0, import_stmt) + node.body.insert(0, self._import_stmt) return node @@ -171,6 +198,8 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: item.decorator_list.insert(0, decorator) self.inserted_decorator = True + break + if not has_init: # Skip dataclasses — their __init__ is auto-generated at class creation time and isn't in the AST. # The synthetic __init__ with super().__init__(*args, **kwargs) overrides it and fails because @@ -181,6 +210,18 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: dec_name = self._expr_name(dec) if dec_name is not None and dec_name.endswith("dataclass"): return node + if dec_name is not None: + parts = dec_name.split(".") + if len(parts) >= 2 and parts[-2] in _ATTRS_NAMESPACES and parts[-1] in _ATTRS_DECORATOR_NAMES: + if isinstance(dec, ast.Call): + for kw in dec.keywords: + if kw.arg == "init" and isinstance(kw.value, ast.Constant) and kw.value.value is False: + return node + self._attrs_classes_to_patch[node.name] = decorator + self.inserted_decorator = True + return node + + # Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments) # Skip NamedTuples — their __init__ is synthesized and cannot be overwritten. for base in node.bases: @@ -202,6 +243,60 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: return node + def _build_attrs_patch_block(self, class_name: str, decorator: ast.Call) -> list[ast.stmt]: + orig_name = f"_codeflash_orig_{class_name}_init" + patched_name = f"_codeflash_patched_{class_name}_init" + + # _codeflash_orig_ClassName_init = ClassName.__init__ + + # Create class name nodes once + class_name_load = ast.Name(id=class_name, ctx=self._load_ctx) + + # _codeflash_orig_ClassName_init = ClassName.__init__ + save_orig = ast.Assign( + targets=[ast.Name(id=orig_name, ctx=self._store_ctx)], + value=ast.Attribute(value=class_name_load, attr="__init__", ctx=self._load_ctx), + ) + + # def _codeflash_patched_ClassName_init(self, *args, **kwargs): + # return _codeflash_orig_ClassName_init(self, *args, **kwargs) + patched_func = ast.FunctionDef( + name=patched_name, + args=ast.arguments( + posonlyargs=[], + args=[self._self_arg_node], + vararg=self._args_arg_node, + kwonlyargs=[], + kw_defaults=[], + kwarg=self._kwargs_arg_node, + defaults=[], + ), + body=cast( + "list[ast.stmt]", + [ + ast.Return( + value=ast.Call( + func=ast.Name(id=orig_name, ctx=self._load_ctx), + args=[self._self_name_load, self._starred_args], + keywords=[self._kwargs_keyword], + ) + ) + ], + ), + decorator_list=cast("list[ast.expr]", []), + returns=None, + ) + + # ClassName.__init__ = codeflash_capture(...)(_codeflash_patched_ClassName_init) + assign_patched = ast.Assign( + targets=[ + ast.Attribute(value=ast.Name(id=class_name, ctx=self._load_ctx), attr="__init__", ctx=self._store_ctx) + ], + value=ast.Call(func=decorator, args=[ast.Name(id=patched_name, ctx=self._load_ctx)], keywords=[]), + ) + + return [save_orig, patched_func, assign_patched] + def _expr_name(self, node: ast.AST) -> str | None: if isinstance(node, ast.Name): return node.id diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index a2b31eb94..ccfa5410d 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -5013,6 +5013,68 @@ def process(cfg: ChildConfig) -> str: assert "qualified_name: str" in combined +def test_extract_init_stub_attrs_define(tmp_path: Path) -> None: + """extract_init_stub_from_class produces a synthetic __init__ stub for @attrs.define classes.""" + source = """ +import attrs +from attrs.validators import instance_of + +@attrs.define(frozen=True) +class ImportCST: + module: str = attrs.field(converter=str.lower) + name: str = attrs.field(validator=[instance_of(str)]) + as_name: str = attrs.field(validator=[instance_of(str)]) + + def to_str(self) -> str: + return f"from {self.module} import {self.name}" +""" + expected = "class ImportCST:\n def __init__(self, module: str, name: str, as_name: str):\n ..." + tree = ast.parse(source) + stub = extract_init_stub_from_class("ImportCST", source, tree) + assert stub == expected + + +def test_extract_init_stub_attrs_factory_fields(tmp_path: Path) -> None: + """Fields using attrs factory= keyword should appear as optional (= ...) in the stub.""" + source = """ +import attrs + +@attrs.define +class ClassCST: + name: str = attrs.field() + methods: list = attrs.field(factory=list) + imports: set = attrs.field(factory=set) + + def compute(self) -> int: + return len(self.methods) +""" + expected = "class ClassCST:\n def __init__(self, name: str, methods: list = ..., imports: set = ...):\n ..." + tree = ast.parse(source) + stub = extract_init_stub_from_class("ClassCST", source, tree) + assert stub == expected + + +def test_extract_init_stub_attrs_init_disabled(tmp_path: Path) -> None: + """When @attrs.define(init=False) but with explicit __init__, the explicit body is returned.""" + source = """ +import attrs + +@attrs.define(init=False) +class NoAutoInit: + x: int = attrs.field() + + def __init__(self, x: int): + self.x = x * 2 + + def get(self) -> int: + return self.x +""" + expected = "class NoAutoInit:\n def __init__(self, x: int):\n self.x = x * 2" + tree = ast.parse(source) + stub = extract_init_stub_from_class("NoAutoInit", source, tree) + assert stub == expected + + def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None: """Third-party classes should produce compact __init__ stubs, not full class source.""" # Use a real third-party package (pydantic) so jedi can actually resolve it diff --git a/tests/test_instrument_codeflash_capture.py b/tests/test_instrument_codeflash_capture.py index 543d50855..4ad0fade1 100644 --- a/tests/test_instrument_codeflash_capture.py +++ b/tests/test_instrument_codeflash_capture.py @@ -2,8 +2,8 @@ from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import FunctionParent from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture +from codeflash.models.models import FunctionParent def test_add_codeflash_capture(): @@ -499,6 +499,184 @@ def display(self): test_path.unlink(missing_ok=True) +def test_attrs_define_patched_via_module_wrapper(): + """@attrs.define classes must NOT get a synthetic body __init__; instead a module-level + monkey-patch block is emitted after the class to avoid the __class__ cell TypeError + that arises when attrs.define(slots=True) replaces the original class object. + """ + original_code = """ +import attrs +from attrs.validators import instance_of + +@attrs.define +class MyAttrsClass: + x: int = attrs.field(validator=[instance_of(int)]) + y: str = attrs.field(default="hello") + + def compute(self): + return self.x +""" + test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve() + expected = f"""import attrs +from attrs.validators import instance_of + +from codeflash.verification.codeflash_capture import codeflash_capture + + +@attrs.define +class MyAttrsClass: + x: int = attrs.field(validator=[instance_of(int)]) + y: str = attrs.field(default='hello') + + def compute(self): + return self.x +_codeflash_orig_MyAttrsClass_init = MyAttrsClass.__init__ + +def _codeflash_patched_MyAttrsClass_init(self, *args, **kwargs): + return _codeflash_orig_MyAttrsClass_init(self, *args, **kwargs) +MyAttrsClass.__init__ = codeflash_capture(function_name='MyAttrsClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)(_codeflash_patched_MyAttrsClass_init) +""" + test_path.write_text(original_code) + + function = FunctionToOptimize( + function_name="compute", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyAttrsClass")] + ) + + try: + instrument_codeflash_capture(function, {}, test_path.parent) + modified_code = test_path.read_text() + assert modified_code.strip() == expected.strip() + finally: + test_path.unlink(missing_ok=True) + + +def test_attrs_define_frozen_patched_via_module_wrapper(): + """@attrs.define(frozen=True) should also be monkey-patched at module level.""" + original_code = """ +import attrs + +@attrs.define(frozen=True) +class FrozenPoint: + x: float = attrs.field() + y: float = attrs.field() + + def distance(self): + return (self.x ** 2 + self.y ** 2) ** 0.5 +""" + test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve() + expected = f"""import attrs + +from codeflash.verification.codeflash_capture import codeflash_capture + + +@attrs.define(frozen=True) +class FrozenPoint: + x: float = attrs.field() + y: float = attrs.field() + + def distance(self): + return (self.x ** 2 + self.y ** 2) ** 0.5 +_codeflash_orig_FrozenPoint_init = FrozenPoint.__init__ + +def _codeflash_patched_FrozenPoint_init(self, *args, **kwargs): + return _codeflash_orig_FrozenPoint_init(self, *args, **kwargs) +FrozenPoint.__init__ = codeflash_capture(function_name='FrozenPoint.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)(_codeflash_patched_FrozenPoint_init) +""" + test_path.write_text(original_code) + + function = FunctionToOptimize( + function_name="distance", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="FrozenPoint")] + ) + + try: + instrument_codeflash_capture(function, {}, test_path.parent) + modified_code = test_path.read_text() + assert modified_code.strip() == expected.strip() + finally: + test_path.unlink(missing_ok=True) + + +def test_attr_s_patched_via_module_wrapper(): + """@attr.s classes should also be monkey-patched at module level.""" + original_code = """ +import attr + +@attr.s +class MyAttrClass: + x: int = attr.ib() + + def display(self): + return self.x +""" + test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve() + expected = f"""import attr + +from codeflash.verification.codeflash_capture import codeflash_capture + + +@attr.s +class MyAttrClass: + x: int = attr.ib() + + def display(self): + return self.x +_codeflash_orig_MyAttrClass_init = MyAttrClass.__init__ + +def _codeflash_patched_MyAttrClass_init(self, *args, **kwargs): + return _codeflash_orig_MyAttrClass_init(self, *args, **kwargs) +MyAttrClass.__init__ = codeflash_capture(function_name='MyAttrClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)(_codeflash_patched_MyAttrClass_init) +""" + test_path.write_text(original_code) + + function = FunctionToOptimize( + function_name="display", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyAttrClass")] + ) + + try: + instrument_codeflash_capture(function, {}, test_path.parent) + modified_code = test_path.read_text() + assert modified_code.strip() == expected.strip() + finally: + test_path.unlink(missing_ok=True) + + +def test_attrs_define_init_false_skipped(): + """@attrs.define(init=False) should NOT be monkey-patched because attrs won't generate an __init__.""" + original_code = """ +import attrs + +@attrs.define(init=False) +class ManualInit: + x: int = attrs.field() + + def compute(self): + return self.x +""" + expected = """import attrs + + +@attrs.define(init=False) +class ManualInit: + x: int = attrs.field() + + def compute(self): + return self.x +""" + test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve() + test_path.write_text(original_code) + + function = FunctionToOptimize( + function_name="compute", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="ManualInit")] + ) + + try: + instrument_codeflash_capture(function, {}, test_path.parent) + modified_code = test_path.read_text() + assert modified_code.strip() == expected.strip() + finally: + test_path.unlink(missing_ok=True) + + def test_dataclass_with_explicit_init_still_instrumented(): """A dataclass that defines its own __init__ should still be instrumented normally.""" original_code = """ diff --git a/tests/test_parse_line_profile_test_output.py b/tests/test_parse_line_profile_test_output.py index e9ce3ef00..cd1e4935a 100644 --- a/tests/test_parse_line_profile_test_output.py +++ b/tests/test_parse_line_profile_test_output.py @@ -9,7 +9,6 @@ def test_parse_line_profile_results_non_python_java_json(): set_current_language(Language.JAVA) - with TemporaryDirectory() as tmpdir: tmp_path = Path(tmpdir) source_file = tmp_path / "Util.java"