diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py
index ed61e8c58..f7c5a425f 100644
--- a/codeflash/api/aiservice.py
+++ b/codeflash/api/aiservice.py
@@ -118,7 +118,7 @@ def optimize_python_code(  # noqa: D417
 
     if response.status_code == 200:
         optimizations_json = response.json()["optimizations"]
-        logger.info(f"Generated {len(optimizations_json)} candidates.")
+        logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
         console.rule()
         end_time = time.perf_counter()
         logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
@@ -189,7 +189,7 @@ def optimize_python_code_line_profiler(  # noqa: D417
 
     if response.status_code == 200:
         optimizations_json = response.json()["optimizations"]
-        logger.info(f"Generated {len(optimizations_json)} candidates.")
+        logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
         console.rule()
         return [
             OptimizedCandidate(
diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py
index fefcb0822..87c54b148 100644
--- a/codeflash/api/cfapi.py
+++ b/codeflash/api/cfapi.py
@@ -7,6 +7,7 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional
 
+import git
 import requests
 import sentry_sdk
 from pydantic.json import pydantic_encoder
@@ -191,3 +192,34 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
         return {}
 
     return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()}
+
+
+def is_function_being_optimized_again(
+    owner: str, repo: str, pr_number: int, code_contexts: list[dict[str, str]]
+) -> Any:  # noqa: ANN401
+    """Check whether the given code contexts were already optimized in an earlier run on this PR."""
+    response = make_cfapi_request(
+        "/is-already-optimized",
+        "POST",
+        {"owner": owner, "repo": repo, "pr_number": pr_number, "code_contexts": code_contexts},
+    )
+    response.raise_for_status()
+    return response.json()
+
+
+def add_code_context_hash(code_context_hash: str) -> None:
+    """Record a code context hash in the DB cache for the current PR."""
+    pr_number = get_pr_number()
+    if pr_number is None:
+        return
+    try:
+        owner, repo = get_repo_owner_and_name()
+    except git.exc.InvalidGitRepositoryError:
+        return
+
+    if owner and repo:
+        make_cfapi_request(
+            "/add-code-hash",
+            "POST",
+            {"owner": owner, "repo": repo, "pr_number": pr_number, "code_hash": code_context_hash},
+        )
diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py
index fe2fdcdd1..34d50f268 100644
--- a/codeflash/cli_cmds/console.py
+++ b/codeflash/cli_cmds/console.py
@@ -66,18 +66,37 @@ def code_print(code_str: str) -> None:
 
 
 @contextmanager
-def progress_bar(message: str, *, transient: bool = False) -> Generator[TaskID, None, None]:
-    """Display a progress bar with a spinner and elapsed time."""
-    progress = Progress(
-        SpinnerColumn(next(spinners)),
-        *Progress.get_default_columns(),
-        TimeElapsedColumn(),
-        console=console,
-        transient=transient,
-    )
-    task = progress.add_task(message, total=None)
-    with progress:
-        yield task
+def progress_bar(
+    message: str, *, transient: bool = False, revert_to_print: bool = False
+) -> Generator[TaskID, None, None]:
+    """Display a progress bar with a spinner and elapsed time.
+
+    If revert_to_print is True, falls back to printing a single logger.info message
+    instead of showing a progress bar.
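+
+    Callers in this PR pass revert_to_print=bool(get_pr_number()), so CI runs on a
+    pull request emit plain log lines instead of rendering a live progress bar.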
+ """ + if revert_to_print: + logger.info(message) + + # Create a fake task ID since we still need to yield something + class DummyTask: + def __init__(self) -> None: + self.id = 0 + + yield DummyTask().id + else: + progress = Progress( + SpinnerColumn(next(spinners)), + *Progress.get_default_columns(), + TimeElapsedColumn(), + console=console, + transient=transient, + ) + task = progress.add_task(message, total=None) + with progress: + yield task @contextmanager diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index 3e0acafcb..0b8f54204 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -9,3 +9,4 @@ TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget COVERAGE_THRESHOLD = 60.0 MIN_TESTCASE_PASSED_THRESHOLD = 6 +REPEAT_OPTIMIZATION_PROBABILITY = 0.1 diff --git a/codeflash/code_utils/git_utils.py b/codeflash/code_utils/git_utils.py index 8333c1099..875b261cd 100644 --- a/codeflash/code_utils/git_utils.py +++ b/codeflash/code_utils/git_utils.py @@ -5,6 +5,7 @@ import sys import tempfile import time +from functools import cache from io import StringIO from pathlib import Path from typing import TYPE_CHECKING @@ -79,6 +80,7 @@ def get_git_remotes(repo: Repo) -> list[str]: return [remote.name for remote in repository.remotes] +@cache def get_repo_owner_and_name(repo: Repo | None = None, git_remote: str | None = "origin") -> tuple[str, str]: remote_url = get_remote_url(repo, git_remote) # call only once remote_url = remote_url.removesuffix(".git") if remote_url.endswith(".git") else remote_url diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 934d3053b..2971b4e7f 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -1,9 +1,10 @@ from __future__ import annotations +import hashlib import os from collections import defaultdict from itertools import chain -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import libcst as cst @@ -31,8 +32,8 @@ def get_code_optimization_context( function_to_optimize: FunctionToOptimize, project_root_path: Path, - optim_token_limit: int = 8000, - testgen_token_limit: int = 8000, + optim_token_limit: int = 16000, + testgen_token_limit: int = 16000, ) -> CodeOptimizationContext: # Get FunctionSource representation of helpers of FTO helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi( @@ -73,6 +74,13 @@ def get_code_optimization_context( remove_docstrings=False, code_context_type=CodeContextType.READ_ONLY, ) + hashing_code_context = extract_code_markdown_context_from_files( + helpers_of_fto_dict, + helpers_of_helpers_dict, + project_root_path, + remove_docstrings=True, + code_context_type=CodeContextType.HASHING, + ) # Handle token limits final_read_writable_tokens = encoded_tokens_len(final_read_writable_code) @@ -125,11 +133,15 @@ def get_code_optimization_context( testgen_context_code_tokens = encoded_tokens_len(testgen_context_code) if testgen_context_code_tokens > testgen_token_limit: raise ValueError("Testgen code context has exceeded token limit, cannot proceed") + code_hash_context = hashing_code_context.markdown + code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest() return CodeOptimizationContext( testgen_context_code=testgen_context_code, read_writable_code=final_read_writable_code, read_only_context_code=read_only_context_code, + hashing_code_context=code_hash_context, + 
hashing_code_context_hash=code_hash,
         helper_functions=helpers_of_fto_list,
         preexisting_objects=preexisting_objects,
     )
@@ -309,8 +321,8 @@ def extract_code_markdown_context_from_files(
             logger.debug(f"Error while getting read-only code: {e}")
             continue
         if code_context.strip():
-            code_context_with_imports = CodeString(
-                code=add_needed_imports_from_module(
+            if code_context_type != CodeContextType.HASHING:
+                code_context = add_needed_imports_from_module(
                     src_module_code=original_code,
                     dst_module_code=code_context,
                     src_path=file_path,
@@ -319,10 +331,9 @@ def extract_code_markdown_context_from_files(
                     helper_functions=list(
                         helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())
                     ),
-                ),
-                file_path=file_path.relative_to(project_root_path),
-            )
-            code_context_markdown.code_strings.append(code_context_with_imports)
+                )
+            code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path))
+            code_context_markdown.code_strings.append(code_string_context)
     # Extract code from file paths containing helpers of helpers
     for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items():
         try:
@@ -343,18 +354,17 @@ def extract_code_markdown_context_from_files(
             continue
 
         if code_context.strip():
-            code_context_with_imports = CodeString(
-                code=add_needed_imports_from_module(
+            if code_context_type != CodeContextType.HASHING:
+                code_context = add_needed_imports_from_module(
                     src_module_code=original_code,
                     dst_module_code=code_context,
                     src_path=file_path,
                     dst_path=file_path,
                     project_root=project_root_path,
                     helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())),
-                ),
-                file_path=file_path.relative_to(project_root_path),
-            )
-            code_context_markdown.code_strings.append(code_context_with_imports)
+                )
+            code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path))
+            code_context_markdown.code_strings.append(code_string_context)
 
     return code_context_markdown
 
@@ -492,6 +502,8 @@ def parse_code_and_prune_cst(
         filtered_node, found_target = prune_cst_for_testgen_code(
             module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
         )
+    elif code_context_type == CodeContextType.HASHING:
+        filtered_node, found_target = prune_cst_for_code_hashing(module, target_functions)
     else:
         raise ValueError(f"Unknown code_context_type: {code_context_type}")  # noqa: EM102
 
@@ -583,6 +595,91 @@ def prune_cst_for_read_writable_code(  # noqa: PLR0911
     return (node.with_changes(**updates) if updates else node), True
 
 
+def prune_cst_for_code_hashing(  # noqa: PLR0911
+    node: cst.CSTNode, target_functions: set[str], prefix: str = ""
+) -> tuple[cst.CSTNode | None, bool]:
+    """Recursively filter the node and its children to build the code block used for hashing. This keeps only nodes that lead to target functions.
+
+    Returns
+    -------
+    (filtered_node, found_target):
+        filtered_node: The modified CST node or None if it should be removed.
+        found_target: True if a target function was found in this node's subtree.
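+
+    Imports, docstrings, and nested class definitions are pruned, so the resulting
+    code (and therefore its hash) reflects only the target functions' logic.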
+
+    """
+    if isinstance(node, (cst.Import, cst.ImportFrom)):
+        return None, False
+
+    if isinstance(node, cst.FunctionDef):
+        qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
+        if qualified_name in target_functions:
+            new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body
+            return node.with_changes(body=new_body), True
+        return None, False
+
+    if isinstance(node, cst.ClassDef):
+        # Do not recurse into nested classes
+        if prefix:
+            return None, False
+        # Assuming always an IndentedBlock
+        if not isinstance(node.body, cst.IndentedBlock):
+            raise ValueError("ClassDef body is not an IndentedBlock")  # noqa: TRY004
+        class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
+        new_class_body: list[cst.CSTNode] = []
+        found_target = False
+
+        for stmt in node.body.body:
+            if isinstance(stmt, cst.FunctionDef):
+                qualified_name = f"{class_prefix}.{stmt.name.value}"
+                if qualified_name in target_functions:
+                    stmt_with_changes = stmt.with_changes(
+                        body=remove_docstring_from_body(cast("cst.IndentedBlock", stmt.body))
+                    )
+                    new_class_body.append(stmt_with_changes)
+                    found_target = True
+        # If no target functions were found, remove the class entirely
+        if not found_target:
+            return None, False
+        return node.with_changes(body=cst.IndentedBlock(cast("list[cst.BaseStatement]", new_class_body))), found_target
+
+    # For other nodes, we preserve them only if they contain target functions in their children.
+    section_names = get_section_names(node)
+    if not section_names:
+        return node, False
+
+    updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
+    found_any_target = False
+
+    for section in section_names:
+        original_content = getattr(node, section, None)
+        if isinstance(original_content, (list, tuple)):
+            new_children = []
+            section_found_target = False
+            for child in original_content:
+                filtered, found_target = prune_cst_for_code_hashing(child, target_functions, prefix)
+                if filtered:
+                    new_children.append(filtered)
+                section_found_target |= found_target
+
+            if section_found_target:
+                found_any_target = True
+            updates[section] = new_children
+        elif original_content is not None:
+            filtered, found_target = prune_cst_for_code_hashing(original_content, target_functions, prefix)
+            if found_target:
+                found_any_target = True
+            if filtered:
+                updates[section] = filtered
+
+    if not found_any_target:
+        return None, False
+
+    return (node.with_changes(**updates) if updates else node), True
+
+
 def prune_cst_for_read_only_code(  # noqa: PLR0911
     node: cst.CSTNode,
     target_functions: set[str],
diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py
index 931b3a05a..c50a0ad49 100644
--- a/codeflash/discovery/functions_to_optimize.py
+++ b/codeflash/discovery/functions_to_optimize.py
@@ -14,14 +14,16 @@
 
+import git
 import libcst as cst
 from pydantic.dataclasses import dataclass
 
-from codeflash.api.cfapi import get_blocklisted_functions
+from codeflash.api.cfapi import get_blocklisted_functions, is_function_being_optimized_again
 from codeflash.cli_cmds.console import DEBUG_MODE, console, logger
 from codeflash.code_utils.code_utils import (
     is_class_defined_in_file,
     module_name_from_file_path,
     path_belongs_to_site_packages,
 )
-from codeflash.code_utils.git_utils import get_git_diff
+from codeflash.code_utils.env_utils import get_pr_number
+from codeflash.code_utils.git_utils import get_git_diff, get_repo_owner_and_name
 from codeflash.code_utils.time_utils import humanize_runtime
 from codeflash.discovery.discover_unit_tests import discover_unit_tests
 from codeflash.models.models import FunctionParent
@@ -31,6 +33,7 @@
     from libcst import CSTNode
     from libcst.metadata import CodeRange
 
+    from codeflash.models.models import CodeOptimizationContext
     from codeflash.verification.verification_utils import TestConfig
 
 
@@ -417,6 +420,44 @@ def inspect_top_level_functions_or_methods(
     )
 
 
+def was_function_previously_optimized(
+    function_to_optimize: FunctionToOptimize, code_context: CodeOptimizationContext
+) -> bool:
+    """Check whether this function was already optimized in a previous run for this PR.
+
+    Looks up the function's code context hash via the optimization API and returns
+    True if the same hash was already recorded, False otherwise (including when
+    repo/PR information is unavailable or the API call fails).
+    """
+    try:
+        owner, repo = get_repo_owner_and_name()
+    except git.exc.InvalidGitRepositoryError:
+        logger.warning("No git repository found")
+        owner, repo = None, None
+    pr_number = get_pr_number()
+
+    if not owner or not repo or pr_number is None:
+        return False
+
+    code_contexts = [
+        {
+            "file_path": function_to_optimize.file_path,
+            "function_name": function_to_optimize.qualified_name,
+            "code_hash": code_context.hashing_code_context_hash,
+        }
+    ]
+
+    try:
+        result = is_function_being_optimized_again(owner, repo, pr_number, code_contexts)
+        already_optimized_paths: list[tuple[str, str]] = result.get("already_optimized_tuples", [])
+        return len(already_optimized_paths) > 0
+    except Exception as e:
+        logger.warning(f"Failed to check optimization status: {e}")
+        # Treat the function as not previously optimized if the API call fails
+        return False
+
+
 def filter_functions(
     modified_functions: dict[Path, list[FunctionToOptimize]],
     tests_root: Path,
@@ -426,14 +467,14 @@ def filter_functions(
     previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
     disable_logs: bool = False,  # noqa: FBT001, FBT002
 ) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
+    filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {}
     blocklist_funcs = get_blocklisted_functions()
     logger.debug(f"Blocklisted functions: {blocklist_funcs}")
 
     # Remove any function that we don't want to optimize
 
     # Ignore files with submodule path, cache the submodule paths
     submodule_paths = ignored_submodule_paths(module_root)
-    filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {}
     functions_count: int = 0
     test_functions_removed_count: int = 0
     non_modules_removed_count: int = 0
@@ -445,6 +486,7 @@ def filter_functions(
     previous_checkpoint_functions_removed_count: int = 0
     tests_root_str = str(tests_root)
     module_root_str = str(module_root)
+
     # We desperately need Python 3.10+ only support to make this code readable with structural pattern matching
     for file_path_path, functions in modified_functions.items():
         _functions = functions
@@ -473,6 +515,7 @@ def filter_functions(
         except SyntaxError:
             malformed_paths_count += 1
             continue
+
         if blocklist_funcs:
             functions_tmp = []
             for function in _functions:
diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index b250d2474..02db2d0b6 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -157,6 +157,8 @@ class CodeOptimizationContext(BaseModel):
     testgen_context_code: str = ""
     read_writable_code: str = Field(min_length=1)
     read_only_context_code: str = ""
+    hashing_code_context: str = ""
+    hashing_code_context_hash: str = ""
     helper_functions: list[FunctionSource]
     preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]]
 
@@ -165,6 +167,7 @@ class CodeContextType(str, Enum):
     READ_WRITABLE = "READ_WRITABLE"
     READ_ONLY = "READ_ONLY"
     TESTGEN = "TESTGEN"
+    HASHING = "HASHING"
 
 
 class OptimizedCandidateResult(BaseModel):
@@ -421,7 +424,7 @@ def id(self) -> str:
         )
 
     @staticmethod
-    def from_str_id(string_id: str, iteration_id: Optional[str] = None) -> InvocationId:
+    def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId:
         components = string_id.split(":")
         assert len(components) == 4
         second_components = components[1].split(".")
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 9f5781697..4edbf8974 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -3,6 +3,7 @@
 import ast
 import concurrent.futures
 import os
+import random
 import subprocess
 import time
 import uuid
@@ -18,6 +19,7 @@
 from rich.tree import Tree
 
 from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
+from codeflash.api.cfapi import add_code_context_hash
 from codeflash.benchmarking.utils import process_benchmark_data
 from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
 from codeflash.code_utils import env_utils
@@ -39,12 +41,14 @@
     INDIVIDUAL_TESTCASE_TIMEOUT,
     N_CANDIDATES,
     N_TESTS_TO_GENERATE,
+    REPEAT_OPTIMIZATION_PROBABILITY,
     TOTAL_LOOPING_TIME,
 )
 from codeflash.code_utils.edit_generated_tests import (
     add_runtime_comments_to_generated_tests,
     remove_functions_from_generated_tests,
 )
+from codeflash.code_utils.env_utils import get_pr_number
 from codeflash.code_utils.formatter import format_code, sort_imports
 from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
 from codeflash.code_utils.line_profile_utils import add_decorator_imports
@@ -52,6 +56,7 @@
 from codeflash.code_utils.time_utils import humanize_runtime
 from codeflash.context import code_context_extractor
 from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
+from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
 from codeflash.either import Failure, Success, is_successful
 from codeflash.models.ExperimentMetadata import ExperimentMetadata
 from codeflash.models.models import (
@@ -155,8 +160,20 @@ def optimize_function(self) -> Result[BestOptimization, str]:  # noqa: PLR0911
             with helper_function_path.open(encoding="utf8") as f:
                 helper_code = f.read()
                 original_helper_code[helper_function_path] = helper_code
+
         if has_any_async_functions(code_context.read_writable_code):
             return Failure("Codeflash does not support async functions in the code to optimize.")
+        # The random draw below gives previously optimized functions a small chance of being
+        # retried: if no optimization was found last time, a later attempt may still find one.
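+        # Gating sketch (with REPEAT_OPTIMIZATION_PROBABILITY = 0.1 from config_consts):
+        #   ~90% of runs: skip this function if its code hash was already recorded for this PR
+        #   ~10% of runs: proceed anyway and re-attempt the optimization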
+        # The random check runs first as a micro-optimization: short-circuiting the `and`
+        # usually avoids the API call inside was_function_previously_optimized().
+        if random.random() > REPEAT_OPTIMIZATION_PROBABILITY and was_function_previously_optimized(  # noqa: S311
+            self.function_to_optimize, code_context
+        ):
+            return Failure("Function optimization previously attempted, skipping.")
         code_print(code_context.read_writable_code)
 
         generated_test_paths = [
@@ -175,6 +192,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:  # noqa: PLR0911
         with progress_bar(
             f"Generating new tests and optimizations for function {self.function_to_optimize.function_name}",
             transient=True,
+            revert_to_print=bool(get_pr_number()),
         ):
             generated_results = self.generate_tests_and_optimizations(
                 testgen_context_code=code_context.testgen_context_code,
@@ -375,6 +393,9 @@ def optimize_function(self) -> Result[BestOptimization, str]:  # noqa: PLR0911
             )
             self.log_successful_optimization(explanation, generated_tests, exp_type)
 
+        # Record this function's code context hash so future runs on this PR can skip it
+        add_code_context_hash(code_context.hashing_code_context_hash)
+
         if self.args.override_fixtures:
             restore_conftest(original_conftest_content)
         if not best_optimization:
@@ -684,6 +705,8 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
                 testgen_context_code=new_code_ctx.testgen_context_code,
                 read_writable_code=new_code_ctx.read_writable_code,
                 read_only_context_code=new_code_ctx.read_only_context_code,
+                hashing_code_context=new_code_ctx.hashing_code_context,
+                hashing_code_context_hash=new_code_ctx.hashing_code_context_hash,
                 helper_functions=new_code_ctx.helper_functions,  # only functions that are read writable
                 preexisting_objects=new_code_ctx.preexisting_objects,
             )
diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py
index 55ab14c35..0401efe31 100644
--- a/codeflash/optimization/optimizer.py
+++ b/codeflash/optimization/optimizer.py
@@ -11,6 +11,7 @@
 from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
 from codeflash.cli_cmds.console import console, logger, progress_bar
 from codeflash.code_utils import env_utils
+from codeflash.code_utils.env_utils import get_pr_number
 from codeflash.either import is_successful
 from codeflash.models.models import ValidCode
 from codeflash.telemetry.posthog_cf import ph
@@ -110,7 +111,12 @@ def run(self) -> None:
         from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
         from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
 
-        with progress_bar(f"Running benchmarks in {self.args.benchmarks_root}", transient=True):
+        console.rule()
+        with progress_bar(
+            f"Running benchmarks in {self.args.benchmarks_root}",
+            transient=True,
+            revert_to_print=bool(get_pr_number()),
+        ):
             # Insert decorator
             file_path_to_source_code = defaultdict(str)
             for file in file_to_funcs_to_optimize:
diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py
index 90356ac10..2d4dd56cb 100644
--- a/tests/test_code_context_extractor.py
+++ b/tests/test_code_context_extractor.py
@@ -6,7 +6,6 @@
 from pathlib import Path
 
 import pytest
-
 from codeflash.context.code_context_extractor import get_code_optimization_context
 from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 from codeflash.models.models import FunctionParent
@@ -30,6 +29,7 @@ def __init__(self, name):
     def nested_method(self):
         return self.name
 
+
 def main_method():
     return "hello"
 
@@ -81,8 +81,9 @@ def test_code_replacement10() -> 
None: code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent) qualified_names = {func.qualified_name for func in code_ctx.helper_functions} - assert qualified_names == {"HelperClass.helper_method"} # Nested method should not be in here + assert qualified_names == {"HelperClass.helper_method"} # Nested method should not be in here read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ from __future__ import annotations @@ -106,8 +107,26 @@ def main_method(self): expected_read_only_context = """ """ + expected_hashing_context = f""" +```python:{file_path.relative_to(file_path.parent)} +class HelperClass: + + def helper_method(self): + return self.name + + +class MainClass: + + def main_method(self): + self.name = HelperClass.NestedClass("test").nested_method() + return HelperClass(self.name).helper_method() +``` +""" + assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() + def test_class_method_dependencies() -> None: file_path = Path(__file__).resolve() @@ -122,6 +141,8 @@ def test_class_method_dependencies() -> None: code_ctx = get_code_optimization_context(function_to_optimize, file_path.parent.resolve()) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context + expected_read_write_context = """ from __future__ import annotations from collections import defaultdict @@ -153,8 +174,36 @@ def topologicalSort(self): """ expected_read_only_context = "" + + expected_hashing_context = f""" +```python:{file_path.relative_to(file_path.parent.resolve())} +class Graph: + + def topologicalSortUtil(self, v, visited, stack): + visited[v] = True + + for i in self.graph[v]: + if visited[i] == False: + self.topologicalSortUtil(i, visited, stack) + + stack.insert(0, v) + + def topologicalSort(self): + visited = [False] * self.V + stack = [] + + for i in range(self.V): + if visited[i] == False: + self.topologicalSortUtil(i, visited, stack) + + # Print contents of stack + return stack +``` +""" + assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_bubble_sort_helper() -> None: @@ -176,6 +225,7 @@ def test_bubble_sort_helper() -> None: code_ctx = get_code_optimization_context(function_to_optimize, Path(__file__).resolve().parent.parent) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ import math @@ -196,8 +246,24 @@ def sort_from_another_file(arr): """ expected_read_only_context = "" + expected_hashing_context = """ +```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py +def sorter(arr): + arr.sort() + x = math.sqrt(2) + print(x) + return arr +``` +```python:code_to_optimize/code_directories/retriever/bubble_sort_imported.py +def sort_from_another_file(arr): + sorted_arr = sorter(arr) + return sorted_arr +``` +""" + assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == 
expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_flavio_typed_code_helper() -> None: @@ -366,7 +432,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: lifespan=self.__duration__, ) ''' - with tempfile.NamedTemporaryFile(mode="w") as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: f.write(code) f.flush() file_path = Path(f.name).resolve() @@ -391,6 +457,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): @@ -543,8 +610,67 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]): __backend__: _CacheBackendT ``` ''' + expected_hashing_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} +class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): + + def get_cache_or_call( + self, + *, + func: Callable[_P, Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], + lifespan: datetime.timedelta, + ) -> Any: # noqa: ANN401 + if os.environ.get("NO_CACHE"): + return func(*args, **kwargs) + + try: + key = self.hash_key(func=func, args=args, kwargs=kwargs) + except: # noqa: E722 + # If we can't create a cache key, we should just call the function. + logging.warning("Failed to hash cache key for function: %s", func) + return func(*args, **kwargs) + result_pair = self.get(key=key) + + if result_pair is not None: + cached_time, result = result_pair + if not os.environ.get("RE_CACHE") and ( + datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005 + ): + try: + return self.decode(data=result) + except CacheBackendDecodeError as e: + logging.warning("Failed to decode cache data: %s", e) + # If decoding fails we will treat this as a cache miss. + # This might happens if underlying class definition of the data changes. + self.delete(key=key) + result = func(*args, **kwargs) + try: + self.put(key=key, data=self.encode(data=result)) + except CacheBackendEncodeError as e: + logging.warning("Failed to encode cache data: %s", e) + # If encoding fails, we should still return the result. 
+ return result + + +class _PersistentCache(Generic[_P, _R, _CacheBackendT]): + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: + if "NO_CACHE" in os.environ: + return self.__wrapped__(*args, **kwargs) + os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True) + return self.__backend__.get_cache_or_call( + func=self.__wrapped__, + args=args, + kwargs=kwargs, + lifespan=self.__duration__, + ) +``` +""" assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_example_class() -> None: @@ -592,6 +718,8 @@ def helper_method(self): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context + expected_read_write_context = """ class MyClass: def __init__(self): @@ -618,8 +746,21 @@ def __repr__(self): return "HelperClass" + str(self.x) ``` """ + expected_hashing_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} +class MyClass: + def target_method(self): + y = HelperClass().helper_method() + +class HelperClass: + def helper_method(self): + return self.x +``` +""" + assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_example_class_token_limit_1() -> None: @@ -672,6 +813,7 @@ def helper_method(self): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context # In this scenario, the read-only code context is too long, so the read-only docstrings are removed. expected_read_write_context = """ class MyClass: @@ -697,9 +839,21 @@ class HelperClass: def __repr__(self): return "HelperClass" + str(self.x) ``` +""" + expected_hashing_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} +class MyClass: + def target_method(self): + y = HelperClass().helper_method() + +class HelperClass: + def helper_method(self): + return self.x +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_example_class_token_limit_2() -> None: @@ -752,6 +906,7 @@ def helper_method(self): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context # In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely. 
expected_read_write_context = """ class MyClass: @@ -769,8 +924,20 @@ def helper_method(self): return self.x """ expected_read_only_context = "" + expected_hashing_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} +class MyClass: + def target_method(self): + y = HelperClass().helper_method() + +class HelperClass: + def helper_method(self): + return self.x +``` +""" assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_example_class_token_limit_3() -> None: @@ -823,6 +990,7 @@ def helper_method(self): with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + def test_example_class_token_limit_4() -> None: string_filler = " ".join( ["This is a long string that will be used to fill up the token limit." for _ in range(1000)] @@ -875,6 +1043,7 @@ def helper_method(self): with pytest.raises(ValueError, match="Testgen code context has exceeded token limit, cannot proceed"): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + def test_repo_helper() -> None: project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" path_to_file = project_root / "main.py" @@ -889,6 +1058,7 @@ def test_repo_helper() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ import math import requests @@ -938,9 +1108,36 @@ def __repr__(self) -> str: \"\"\"Return a string representation of the DataProcessor.\"\"\" return f"DataProcessor(default_prefix={{self.default_prefix!r}})" ``` +""" + expected_hashing_context = f""" +```python:{path_to_utils.relative_to(project_root)} +class DataProcessor: + + def process_data(self, raw_data: str) -> str: + return raw_data.upper() + + def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str: + return prefix + data +``` +```python:{path_to_file.relative_to(project_root)} +def fetch_and_process_data(): + # Use the global variable for the request + response = requests.get(API_URL) + response.raise_for_status() + + raw_data = response.text + + # Use code from another file (utils.py) + processor = DataProcessor() + processed = processor.process_data(raw_data) + processed = processor.add_prefix(processed) + + return processed +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_repo_helper_of_helper() -> None: @@ -958,6 +1155,7 @@ def test_repo_helper_of_helper() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ import math from transform_utils import DataTransformer @@ -1014,10 +1212,36 @@ def transform(self, data): self.data = data return self.data ``` +""" + expected_hashing_context = f""" +```python:{path_to_utils.relative_to(project_root)} +class DataProcessor: + + def 
process_data(self, raw_data: str) -> str: + return raw_data.upper() + + def transform_data(self, data: str) -> str: + return DataTransformer().transform(data) +``` +```python:{path_to_file.relative_to(project_root)} +def fetch_and_transform_data(): + # Use the global variable for the request + response = requests.get(API_URL) + + raw_data = response.text + + # Use code from another file (utils.py) + processor = DataProcessor() + processed = processor.process_data(raw_data) + transformed = processor.transform_data(processed) + + return transformed +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_repo_helper_of_helper_same_class() -> None: @@ -1034,6 +1258,7 @@ def test_repo_helper_of_helper_same_class() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ import math from transform_utils import DataTransformer @@ -1078,10 +1303,25 @@ def __repr__(self) -> str: return f"DataProcessor(default_prefix={{self.default_prefix!r}})" ``` +""" + expected_hashing_context = f""" +```python:transform_utils.py +class DataTransformer: + + def transform_using_own_method(self, data): + return self.transform(data) +``` +```python:{path_to_utils.relative_to(project_root)} +class DataProcessor: + + def transform_data_own_method(self, data: str) -> str: + return DataTransformer().transform_using_own_method(data) +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_repo_helper_of_helper_same_file() -> None: @@ -1098,6 +1338,7 @@ def test_repo_helper_of_helper_same_file() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ import math from transform_utils import DataTransformer @@ -1137,10 +1378,25 @@ def __repr__(self) -> str: \"\"\"Return a string representation of the DataProcessor.\"\"\" return f"DataProcessor(default_prefix={{self.default_prefix!r}})" ``` +""" + expected_hashing_context = f""" +```python:transform_utils.py +class DataTransformer: + + def transform_using_same_file_function(self, data): + return update_data(data) +``` +```python:{path_to_utils.relative_to(project_root)} +class DataProcessor: + + def transform_data_same_file_function(self, data: str) -> str: + return DataTransformer().transform_using_same_file_function(data) +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_repo_helper_all_same_file() -> None: @@ -1156,6 +1412,7 @@ def test_repo_helper_all_same_file() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ 
class DataTransformer: def __init__(self): @@ -1181,10 +1438,27 @@ def transform(self, data): return self.data ``` +""" + expected_hashing_context = f""" +```python:{path_to_transform_utils.relative_to(project_root)} +class DataTransformer: + + def transform_using_own_method(self, data): + return self.transform(data) + + def transform_data_all_same_file(self, data): + new_data = update_data(data) + return self.transform_using_own_method(new_data) + + +def update_data(data): + return data + " updated" +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() def test_repo_helper_circular_dependency() -> None: @@ -1201,6 +1475,7 @@ def test_repo_helper_circular_dependency() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ import math from transform_utils import DataTransformer @@ -1240,10 +1515,26 @@ def __repr__(self) -> str: return f"DataProcessor(default_prefix={{self.default_prefix!r}})" ``` +""" + expected_hashing_context = f""" +```python:utils.py +class DataProcessor: + + def circular_dependency(self, data: str) -> str: + return DataTransformer().circular_dependency(data) +``` +```python:{path_to_transform_utils.relative_to(project_root)} +class DataTransformer: + + def circular_dependency(self, data): + return DataProcessor().circular_dependency(data) +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() + def test_indirect_init_helper() -> None: code = """ @@ -1282,6 +1573,7 @@ def outside_method(): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_write_context = """ class MyClass: def __init__(self): @@ -1295,9 +1587,18 @@ def target_method(self): def outside_method(): return 1 ``` +""" + expected_hashing_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} +class MyClass: + def target_method(self): + return self.x + self.y +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() + def test_direct_module_import() -> None: project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" @@ -1311,9 +1612,9 @@ def test_direct_module_import() -> None: ending_line=None, ) - code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context expected_read_only_context = """ ```python:utils.py @@ -1336,6 +1637,26 @@ def transform_data(self, data: str) -> str: \"\"\"Transform the processed data\"\"\" return DataTransformer().transform(data) ```""" + expected_hashing_context = """ +```python:main.py +def fetch_and_transform_data(): + # Use the global variable for the 
request + response = requests.get(API_URL) + + raw_data = response.text + + # Use code from another file (utils.py) + processor = DataProcessor() + processed = processor.process_data(raw_data) + transformed = processor.transform_data(processed) + + return transformed +``` +```python:import_test.py +def function_to_optimize(): + return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data() +``` +""" expected_read_write_context = """ import requests from globals import API_URL @@ -1362,9 +1683,11 @@ def function_to_optimize(): """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() + def test_module_import_optimization() -> None: - main_code = ''' + main_code = """ import utility_module class Calculator: @@ -1391,9 +1714,9 @@ def calculate(self, operation, x, y): return self.subtract(x, y) else: return None -''' +""" - utility_module_code = ''' + utility_module_code = """ import sys import platform import logging @@ -1466,7 +1789,7 @@ def get_system_details(): "default_precision": DEFAULT_PRECISION, "python_version": sys.version } -''' +""" # Create a temporary directory for the test with tempfile.TemporaryDirectory() as temp_dir: @@ -1515,6 +1838,7 @@ def get_system_details(): # Get the code optimization context code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context # The expected contexts expected_read_write_context = """ import utility_module @@ -1579,13 +1903,34 @@ def select_precision(precision, fallback_precision): else: return DEFAULT_PRECISION ``` +""" + expected_hashing_context = """ +```python:main_module.py +class Calculator: + + def add(self, a, b): + return a + b + + def subtract(self, a, b): + return a - b + + def calculate(self, operation, x, y): + if operation == "add": + return self.add(x, y) + elif operation == "subtract": + return self.subtract(x, y) + else: + return None +``` """ # Verify the contexts match the expected values assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() + def test_module_import_init_fto() -> None: - main_code = ''' + main_code = """ import utility_module class Calculator: @@ -1612,9 +1957,9 @@ def calculate(self, operation, x, y): return self.subtract(x, y) else: return None -''' +""" - utility_module_code = ''' + utility_module_code = """ import sys import platform import logging @@ -1687,7 +2032,7 @@ def get_system_details(): "default_precision": DEFAULT_PRECISION, "python_version": sys.version } -''' +""" # Create a temporary directory for the test with tempfile.TemporaryDirectory() as temp_dir: @@ -1791,4 +2136,336 @@ def __init__(self, precision="high", fallback_precision=None, mode="standard"): ``` """ assert read_write_context.strip() == expected_read_write_context.strip() - assert read_only_context.strip() == expected_read_only_context.strip() \ No newline at end of file + assert read_only_context.strip() == expected_read_only_context.strip() + + +def test_hashing_code_context_removes_imports_docstrings_and_init() -> None: + """Test that hashing context removes imports, docstrings, and __init__ methods properly.""" + code = 
'''
+import os
+import sys
+from pathlib import Path
+
+class MyClass:
+    """A class with a docstring."""
+    def __init__(self, value):
+        """Initialize with a value."""
+        self.value = value
+
+    def target_method(self):
+        """Target method with docstring."""
+        result = self.helper_method()
+        helper_cls = HelperClass()
+        data = helper_cls.process_data()
+        return self.value * 2
+
+    def helper_method(self):
+        """Helper method with docstring."""
+        return self.value + 1
+
+class HelperClass:
+    """Helper class docstring."""
+    def __init__(self):
+        """Helper init method."""
+        self.data = "test"
+
+    def process_data(self):
+        """Process data method."""
+        return self.data.upper()
+
+def standalone_function():
+    """Standalone function."""
+    return "standalone"
+'''
+    with tempfile.NamedTemporaryFile(mode="w") as f:
+        f.write(code)
+        f.flush()
+        file_path = Path(f.name).resolve()
+        opt = Optimizer(
+            Namespace(
+                project_root=file_path.parent.resolve(),
+                disable_telemetry=True,
+                tests_root="tests",
+                test_framework="pytest",
+                pytest_cmd="pytest",
+                experiment_id=None,
+                test_project_root=Path().resolve(),
+            )
+        )
+        function_to_optimize = FunctionToOptimize(
+            function_name="target_method",
+            file_path=file_path,
+            parents=[FunctionParent(name="MyClass", type="ClassDef")],
+            starting_line=None,
+            ending_line=None,
+        )
+
+        code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
+        hashing_context = code_ctx.hashing_code_context
+
+        # Expected behavior of the hashing context:
+        # - Should not contain imports
+        # - Should remove docstrings from target and helper functions
+        # - Should not contain __init__ methods
+        # - Should contain target function and helper methods that are actually called
+        # - Should be formatted as markdown
+
+        # Test that it's formatted as markdown
+        assert hashing_context.startswith("```python:")
+        assert hashing_context.endswith("```")
+
+        # Test basic structure requirements
+        assert "import" not in hashing_context  # Should not contain imports
+        assert "__init__" not in hashing_context  # Should not contain __init__ methods
+        assert "target_method" in hashing_context  # Should contain target function
+        assert "standalone_function" not in hashing_context  # Should not contain unused functions
+
+        # Test that helper functions are included when they're called
+        assert "helper_method" in hashing_context  # Should contain called helper method
+        assert "process_data" in hashing_context  # Should contain called helper method
+
+        # Docstrings should be stripped from the hashing context
+        # (remove_docstring_from_body runs on every pruned target and helper body)
+        assert '"""Target method with docstring."""' not in hashing_context, (
+            "Docstrings should be removed from target functions"
+        )
+        assert '"""Helper method with docstring."""' not in hashing_context, (
+            "Docstrings should be removed from helper functions"
+        )
+        assert '"""Process data method."""' not in hashing_context, (
+            "Docstrings should be removed from helper class methods"
+        )
+
+
+def test_hashing_code_context_with_nested_classes() -> None:
+    """Test that hashing context handles nested classes properly (should exclude them)."""
+    code = '''
+class OuterClass:
+    """Outer class docstring."""
+    def __init__(self):
+        """Outer init."""
+        self.value = 1
+
+    def target_method(self):
+        """Target method."""
+        return self.NestedClass().nested_method()
+
+    class NestedClass:
+        """Nested class - should be excluded."""
+        def __init__(self):
+            
self.nested_value = 2 + + def nested_method(self): + return self.nested_value +''' + with tempfile.NamedTemporaryFile(mode="w") as f: + f.write(code) + f.flush() + file_path = Path(f.name).resolve() + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="OuterClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + hashing_context = code_ctx.hashing_code_context + + # Test basic requirements + assert hashing_context.startswith("```python:") + assert hashing_context.endswith("```") + assert "target_method" in hashing_context + assert "__init__" not in hashing_context # Should not contain __init__ methods + + # Verify nested classes are excluded from the hashing context + # The prune_cst_for_code_hashing function should not recurse into nested classes + assert "class NestedClass:" not in hashing_context # Nested class definition should not be present + + # The target method will reference NestedClass, but the actual nested class definition should not be included + # The call to self.NestedClass().nested_method() should be in the target method but the nested class itself excluded + target_method_call_present = "self.NestedClass().nested_method()" in hashing_context + assert target_method_call_present, "The target method should contain the call to nested class" + + # But the actual nested method definition should not be present + nested_method_definition_present = "def nested_method(self):" in hashing_context + assert not nested_method_definition_present, "Nested method definition should not be present in hashing context" + + +def test_hashing_code_context_hash_consistency() -> None: + """Test that the same code produces the same hash.""" + code = """ +class TestClass: + def target_method(self): + return "test" +""" + with tempfile.NamedTemporaryFile(mode="w") as f: + f.write(code) + f.flush() + file_path = Path(f.name).resolve() + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="TestClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + # Generate context twice + code_ctx1 = get_code_optimization_context(function_to_optimize, opt.args.project_root) + code_ctx2 = get_code_optimization_context(function_to_optimize, opt.args.project_root) + + # Hash should be consistent + assert code_ctx1.hashing_code_context_hash == code_ctx2.hashing_code_context_hash + assert code_ctx1.hashing_code_context == code_ctx2.hashing_code_context + + # Hash should be valid SHA256 + import hashlib + + expected_hash = hashlib.sha256(code_ctx1.hashing_code_context.encode("utf-8")).hexdigest() + assert code_ctx1.hashing_code_context_hash == expected_hash + + +def test_hashing_code_context_different_code_different_hash() -> None: + """Test that different code produces different hashes.""" + code1 = """ +class TestClass: + def target_method(self): + return "test1" 
+""" + code2 = """ +class TestClass: + def target_method(self): + return "test2" +""" + + with tempfile.NamedTemporaryFile(mode="w") as f1, tempfile.NamedTemporaryFile(mode="w") as f2: + f1.write(code1) + f1.flush() + f2.write(code2) + f2.flush() + + file_path1 = Path(f1.name).resolve() + file_path2 = Path(f2.name).resolve() + + opt1 = Optimizer( + Namespace( + project_root=file_path1.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + opt2 = Optimizer( + Namespace( + project_root=file_path2.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + + function_to_optimize1 = FunctionToOptimize( + function_name="target_method", + file_path=file_path1, + parents=[FunctionParent(name="TestClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + function_to_optimize2 = FunctionToOptimize( + function_name="target_method", + file_path=file_path2, + parents=[FunctionParent(name="TestClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + code_ctx1 = get_code_optimization_context(function_to_optimize1, opt1.args.project_root) + code_ctx2 = get_code_optimization_context(function_to_optimize2, opt2.args.project_root) + + # Different code should produce different hashes + assert code_ctx1.hashing_code_context_hash != code_ctx2.hashing_code_context_hash + assert code_ctx1.hashing_code_context != code_ctx2.hashing_code_context + + +def test_hashing_code_context_format_is_markdown() -> None: + """Test that hashing context is formatted as markdown.""" + code = """ +class SimpleClass: + def simple_method(self): + return 42 +""" + with tempfile.NamedTemporaryFile(mode="w") as f: + f.write(code) + f.flush() + file_path = Path(f.name).resolve() + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="simple_method", + file_path=file_path, + parents=[FunctionParent(name="SimpleClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + hashing_context = code_ctx.hashing_code_context + + # Should be formatted as markdown code block + assert hashing_context.startswith("```python:") + assert hashing_context.endswith("```") + + # Should contain the relative file path in the markdown header + relative_path = file_path.relative_to(opt.args.project_root) + assert str(relative_path) in hashing_context + + # Should contain the actual code between the markdown markers + lines = hashing_context.strip().split("\n") + assert lines[0].startswith("```python:") + assert lines[-1] == "```" + + # Code should be between the markers + code_lines = lines[1:-1] + code_content = "\n".join(code_lines) + assert "class SimpleClass:" in code_content + assert "def simple_method(self):" in code_content + assert "return 42" in code_content diff --git a/tests/test_git_utils.py b/tests/test_git_utils.py index f456a0d90..293ad5c9e 100644 --- a/tests/test_git_utils.py +++ b/tests/test_git_utils.py @@ -11,30 +11,35 @@ class TestGitUtils(unittest.TestCase): def test_test_get_repo_owner_and_name(self, 
mock_get_remote_url): # Test with a standard GitHub HTTPS URL mock_get_remote_url.return_value = "https://github.com/owner/repo.git" + get_repo_owner_and_name.cache_clear() owner, repo_name = get_repo_owner_and_name() assert owner == "owner" assert repo_name == "repo" # Test with a GitHub SSH URL mock_get_remote_url.return_value = "git@github.com:owner/repo.git" + get_repo_owner_and_name.cache_clear() owner, repo_name = get_repo_owner_and_name() assert owner == "owner" assert repo_name == "repo" # Test with another GitHub SSH URL mock_get_remote_url.return_value = "git@github.com:codeflash-ai/posthog.git" + get_repo_owner_and_name.cache_clear() owner, repo_name = get_repo_owner_and_name() assert owner == "codeflash-ai" assert repo_name == "posthog" # Test with a URL without the .git suffix mock_get_remote_url.return_value = "https://github.com/owner/repo" + get_repo_owner_and_name.cache_clear() owner, repo_name = get_repo_owner_and_name() assert owner == "owner" assert repo_name == "repo" # Test with another GitHub SSH URL mock_get_remote_url.return_value = "git@github.com:codeflash-ai/posthog/" + get_repo_owner_and_name.cache_clear() owner, repo_name = get_repo_owner_and_name() assert owner == "codeflash-ai" assert repo_name == "posthog"