Skip to content
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
1f7124a
WIP
misrasaurabh1 Apr 28, 2025
41378a0
WIP
misrasaurabh1 Apr 28, 2025
e9746c9
batch code hash check
dasarchan Jun 2, 2025
5760316
implemented hash check into filter_functions
dasarchan Jun 3, 2025
905b1a0
Merge branch 'main' into dont-optimize-repeatedly-gh-actions
Jun 3, 2025
2367160
removed prints, added cfapi.py func
dasarchan Jun 5, 2025
f2733b3
Merge branch 'dont-optimize-repeatedly-gh-actions' of https://github.…
dasarchan Jun 5, 2025
c1fb089
removed unused import
dasarchan Jun 5, 2025
3443404
Merge branch 'main' into dont-optimize-repeatedly-gh-actions
misrasaurabh1 Jun 6, 2025
eb3d305
fix no git error
misrasaurabh1 Jun 6, 2025
c862b4d
add low prob of repeating optimization
dasarchan Jun 6, 2025
96ee580
changes to cli for code context hash
dasarchan Jun 7, 2025
87fe086
update the cli
misrasaurabh1 Jun 7, 2025
4cb823e
added separate write route, changed return format for api route
dasarchan Jun 7, 2025
1cc39e3
merge
dasarchan Jun 7, 2025
dd8dceb
removed empty test file
dasarchan Jun 7, 2025
5989b26
updates
dasarchan Jun 7, 2025
5c0a028
Add a first version of hashing code context
misrasaurabh1 Jun 8, 2025
2686682
Might work?
misrasaurabh1 Jun 8, 2025
4f39794
get it working
misrasaurabh1 Jun 8, 2025
50f4c33
10% chance of optimizing again
misrasaurabh1 Jun 8, 2025
81f96ed
Merge branch 'main' into dont-optimize-repeatedly-gh-actions
misrasaurabh1 Jun 8, 2025
c856f1e
fix a bug
misrasaurabh1 Jun 8, 2025
b48ed5c
ruff fix
misrasaurabh1 Jun 8, 2025
9e14cfe
fix bugs with docstring removal
misrasaurabh1 Jun 8, 2025
5d4870f
fix a type
misrasaurabh1 Jun 8, 2025
2c1314d
fix more tests
misrasaurabh1 Jun 8, 2025
32a8001
fix types for python 3.9
misrasaurabh1 Jun 8, 2025
e2f1ba0
clearer message
misrasaurabh1 Jun 8, 2025
f6b3275
fix mypy types
misrasaurabh1 Jun 8, 2025
6ed9387
add more tests
misrasaurabh1 Jun 8, 2025
be1ef9b
fix for test
misrasaurabh1 Jun 8, 2025
9137921
double the context length
misrasaurabh1 Jun 8, 2025
797cba3
ruff revert
misrasaurabh1 Jun 8, 2025
d0f84f6
improve some github actions logging
misrasaurabh1 Jun 8, 2025
2d62171
some refactor
misrasaurabh1 Jun 9, 2025
226acd7
remove unnecessary line
misrasaurabh1 Jun 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions codeflash/api/cfapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional

import git
import requests
import sentry_sdk
from pydantic.json import pydantic_encoder
Expand Down Expand Up @@ -191,3 +192,35 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
return {}

return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()}


def is_function_being_optimized_again(
    owner: str, repo: str, pr_number: int, code_contexts: list[dict[str, str]]
) -> Any:  # noqa: ANN401
    """Ask the Codeflash API whether these code contexts were already optimized for this PR.

    Raises an HTTPError (via raise_for_status) on a failed request; otherwise
    returns the decoded JSON response body.
    """
    payload = {
        "owner": owner,
        "repo": repo,
        "pr_number": pr_number,
        "code_contexts": code_contexts,
    }
    response = make_cfapi_request("/is-already-optimized", "POST", payload)
    response.raise_for_status()
    return response.json()


def add_code_context_hash(code_context_hash: str) -> None:
    """Record a code-context hash in the DB cache for the current PR.

    Best-effort: silently does nothing when there is no PR number or no usable
    git repository, so callers never fail because of missing CI context.
    """
    pr_number = get_pr_number()
    if pr_number is None:
        return
    try:
        owner, repo = get_repo_owner_and_name()
    except git.exc.InvalidGitRepositoryError:
        return

    # pr_number is already known to be non-None here; the original fetched it a
    # second time inside the try block, which was redundant.
    if owner and repo:
        make_cfapi_request(
            "/add-code-hash",
            "POST",
            {"owner": owner, "repo": repo, "pr_number": pr_number, "code_hash": code_context_hash},
        )
1 change: 1 addition & 0 deletions codeflash/code_utils/config_consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget
COVERAGE_THRESHOLD = 60.0
MIN_TESTCASE_PASSED_THRESHOLD = 6
REPEAT_OPTIMIZATION_PROBABILITY = 0.1
2 changes: 2 additions & 0 deletions codeflash/code_utils/git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import tempfile
import time
from functools import cache
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -79,6 +80,7 @@ def get_git_remotes(repo: Repo) -> list[str]:
return [remote.name for remote in repository.remotes]


@cache
def get_repo_owner_and_name(repo: Repo | None = None, git_remote: str | None = "origin") -> tuple[str, str]:
remote_url = get_remote_url(repo, git_remote) # call only once
remote_url = remote_url.removesuffix(".git") if remote_url.endswith(".git") else remote_url
Expand Down
126 changes: 111 additions & 15 deletions codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import hashlib
import os
from collections import defaultdict
from itertools import chain
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import libcst as cst

Expand Down Expand Up @@ -31,8 +32,8 @@
def get_code_optimization_context(
function_to_optimize: FunctionToOptimize,
project_root_path: Path,
optim_token_limit: int = 8000,
testgen_token_limit: int = 8000,
optim_token_limit: int = 16000,
testgen_token_limit: int = 16000,
) -> CodeOptimizationContext:
# Get FunctionSource representation of helpers of FTO
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(
Expand Down Expand Up @@ -73,6 +74,13 @@ def get_code_optimization_context(
remove_docstrings=False,
code_context_type=CodeContextType.READ_ONLY,
)
hashing_code_context = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=True,
code_context_type=CodeContextType.HASHING,
)

# Handle token limits
final_read_writable_tokens = encoded_tokens_len(final_read_writable_code)
Expand Down Expand Up @@ -125,11 +133,15 @@ def get_code_optimization_context(
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
if testgen_context_code_tokens > testgen_token_limit:
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
code_hash_context = hashing_code_context.markdown
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()

return CodeOptimizationContext(
testgen_context_code=testgen_context_code,
read_writable_code=final_read_writable_code,
read_only_context_code=read_only_context_code,
hashing_code_context=code_hash_context,
hashing_code_context_hash=code_hash,
helper_functions=helpers_of_fto_list,
preexisting_objects=preexisting_objects,
)
Expand Down Expand Up @@ -309,8 +321,8 @@ def extract_code_markdown_context_from_files(
logger.debug(f"Error while getting read-only code: {e}")
continue
if code_context.strip():
code_context_with_imports = CodeString(
code=add_needed_imports_from_module(
if code_context_type != CodeContextType.HASHING:
code_context = add_needed_imports_from_module(
src_module_code=original_code,
dst_module_code=code_context,
src_path=file_path,
Expand All @@ -319,10 +331,9 @@ def extract_code_markdown_context_from_files(
helper_functions=list(
helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())
),
),
file_path=file_path.relative_to(project_root_path),
)
code_context_markdown.code_strings.append(code_context_with_imports)
)
code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path))
code_context_markdown.code_strings.append(code_string_context)
# Extract code from file paths containing helpers of helpers
for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items():
try:
Expand All @@ -343,18 +354,17 @@ def extract_code_markdown_context_from_files(
continue

if code_context.strip():
code_context_with_imports = CodeString(
code=add_needed_imports_from_module(
if code_context_type != CodeContextType.HASHING:
code_context = add_needed_imports_from_module(
src_module_code=original_code,
dst_module_code=code_context,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())),
),
file_path=file_path.relative_to(project_root_path),
)
code_context_markdown.code_strings.append(code_context_with_imports)
)
code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path))
code_context_markdown.code_strings.append(code_string_context)
return code_context_markdown


Expand Down Expand Up @@ -492,6 +502,8 @@ def parse_code_and_prune_cst(
filtered_node, found_target = prune_cst_for_testgen_code(
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
)
elif code_context_type == CodeContextType.HASHING:
filtered_node, found_target = prune_cst_for_code_hashing(module, target_functions)
else:
raise ValueError(f"Unknown code_context_type: {code_context_type}") # noqa: EM102

Expand Down Expand Up @@ -583,6 +595,90 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
return (node.with_changes(**updates) if updates else node), True


def prune_cst_for_code_hashing(  # noqa: PLR0911
    node: cst.CSTNode, target_functions: set[str], prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
    """Recursively filter the node and its children to build the code used for hashing.

    Keeps only the target functions (docstrings removed) and the structural nodes
    needed to reach them; imports are dropped entirely so the resulting hash is
    stable under import churn.

    Returns
    -------
    (filtered_node, found_target):
        filtered_node: The modified CST node or None if it should be removed.
        found_target: True if a target function was found in this node's subtree.

    """
    # Imports never contribute to the hash.
    if isinstance(node, (cst.Import, cst.ImportFrom)):
        return None, False

    if isinstance(node, cst.FunctionDef):
        qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
        if qualified_name in target_functions:
            # Strip the docstring so doc-only edits do not change the hash.
            new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body
            return node.with_changes(body=new_body), True
        return None, False

    if isinstance(node, cst.ClassDef):
        # Do not recurse into nested classes
        if prefix:
            return None, False
        # Assuming always an IndentedBlock
        if not isinstance(node.body, cst.IndentedBlock):
            raise ValueError("ClassDef body is not an IndentedBlock")  # noqa: TRY004
        class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
        new_class_body: list[cst.CSTNode] = []
        found_target = False

        for stmt in node.body.body:
            if isinstance(stmt, cst.FunctionDef):
                qualified_name = f"{class_prefix}.{stmt.name.value}"
                if qualified_name in target_functions:
                    stmt_with_changes = stmt.with_changes(
                        body=remove_docstring_from_body(cast("cst.IndentedBlock", stmt.body))
                    )
                    new_class_body.append(stmt_with_changes)
                    found_target = True
        # If no target functions found, remove the class entirely
        if not new_class_body or not found_target:
            return None, False
        # new_class_body is guaranteed non-empty past the guard above, so the
        # original trailing `... if new_class_body else None` ternary was dead.
        return node.with_changes(body=cst.IndentedBlock(cast("list[cst.BaseStatement]", new_class_body))), found_target

    # For other nodes, we preserve them only if they contain target functions in their children.
    section_names = get_section_names(node)
    if not section_names:
        return node, False

    updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
    found_any_target = False

    for section in section_names:
        original_content = getattr(node, section, None)
        if isinstance(original_content, (list, tuple)):
            new_children = []
            section_found_target = False
            for child in original_content:
                filtered, found_target = prune_cst_for_code_hashing(child, target_functions, prefix)
                if filtered:
                    new_children.append(filtered)
                section_found_target |= found_target

            if section_found_target:
                found_any_target = True
                updates[section] = new_children
        elif original_content is not None:
            filtered, found_target = prune_cst_for_code_hashing(original_content, target_functions, prefix)
            if found_target:
                found_any_target = True
                if filtered:
                    updates[section] = filtered

    if not found_any_target:
        return None, False

    return (node.with_changes(**updates) if updates else node), True


def prune_cst_for_read_only_code( # noqa: PLR0911
node: cst.CSTNode,
target_functions: set[str],
Expand Down
64 changes: 61 additions & 3 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
import libcst as cst
from pydantic.dataclasses import dataclass

from codeflash.api.cfapi import get_blocklisted_functions
from codeflash.api.cfapi import get_blocklisted_functions, is_function_being_optimized_again
from codeflash.cli_cmds.console import DEBUG_MODE, console, logger
from codeflash.code_utils.code_utils import (
is_class_defined_in_file,
module_name_from_file_path,
path_belongs_to_site_packages,
)
from codeflash.code_utils.git_utils import get_git_diff
from codeflash.code_utils.env_utils import get_pr_number
from codeflash.code_utils.git_utils import get_git_diff, get_repo_owner_and_name
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.models.models import FunctionParent
Expand All @@ -31,6 +32,7 @@
from libcst import CSTNode
from libcst.metadata import CodeRange

from codeflash.models.models import CodeOptimizationContext
from codeflash.verification.verification_utils import TestConfig


Expand Down Expand Up @@ -417,6 +419,57 @@ def inspect_top_level_functions_or_methods(
)


def check_optimization_status(function_to_optimize: FunctionToOptimize, code_context: CodeOptimizationContext) -> bool:
    """Return True if this function's code context was already optimized for the current PR.

    Queries the optimization API with the function's file path, qualified name,
    and code-context hash. Returns False whenever the repo owner/name or PR
    number is unavailable, or the API call fails, so optimization proceeds by
    default rather than being silently skipped.
    """
    try:
        owner, repo = get_repo_owner_and_name()
    # NOTE(review): this requires `import git` at module level — the visible
    # import changes do not add it; confirm it is present in the file.
    except git.exc.InvalidGitRepositoryError:
        logger.warning("No git repository found")
        owner, repo = None, None
    pr_number = get_pr_number()

    if not owner or not repo or pr_number is None:
        return False

    # Exactly one context: this function, identified by path + qualified name + hash.
    code_contexts = [
        {
            "file_path": function_to_optimize.file_path,
            "function_name": function_to_optimize.qualified_name,
            "code_hash": code_context.hashing_code_context_hash,
        }
    ]

    try:
        result = is_function_being_optimized_again(owner, repo, pr_number, code_contexts)
        already_optimized_paths: list[tuple[str, str]] = result.get("already_optimized_tuples", [])
        return len(already_optimized_paths) > 0
    except Exception as e:
        logger.warning(f"Failed to check optimization status: {e}")
        # Treat an API failure as "not optimized yet" so the function is not skipped.
        return False


def filter_functions(
modified_functions: dict[Path, list[FunctionToOptimize]],
tests_root: Path,
Expand All @@ -426,25 +479,28 @@ def filter_functions(
previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
disable_logs: bool = False, # noqa: FBT001, FBT002
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {}
blocklist_funcs = get_blocklisted_functions()
logger.debug(f"Blocklisted functions: {blocklist_funcs}")
# Remove any function that we don't want to optimize
# already_optimized_paths = check_optimization_status(modified_functions, project_root)

# Ignore files with submodule path, cache the submodule paths
submodule_paths = ignored_submodule_paths(module_root)

filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {}
functions_count: int = 0
test_functions_removed_count: int = 0
non_modules_removed_count: int = 0
site_packages_removed_count: int = 0
ignore_paths_removed_count: int = 0
malformed_paths_count: int = 0
already_optimized_count: int = 0
submodule_ignored_paths_count: int = 0
blocklist_funcs_removed_count: int = 0
previous_checkpoint_functions_removed_count: int = 0
tests_root_str = str(tests_root)
module_root_str = str(module_root)

# We desperately need Python 3.10+ only support to make this code readable with structural pattern matching
for file_path_path, functions in modified_functions.items():
_functions = functions
Expand Down Expand Up @@ -473,6 +529,7 @@ def filter_functions(
except SyntaxError:
malformed_paths_count += 1
continue

if blocklist_funcs:
functions_tmp = []
for function in _functions:
Expand Down Expand Up @@ -507,6 +564,7 @@ def filter_functions(
f"{non_modules_removed_count} function{'s' if non_modules_removed_count != 1 else ''} outside module-root": non_modules_removed_count,
f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count,
f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count,
f"{already_optimized_count} already optimized function{'s' if already_optimized_count != 1 else ''}": already_optimized_count,
f"{blocklist_funcs_removed_count} function{'s' if blocklist_funcs_removed_count != 1 else ''} as previously optimized": blocklist_funcs_removed_count,
f"{previous_checkpoint_functions_removed_count} function{'s' if previous_checkpoint_functions_removed_count != 1 else ''} skipped from checkpoint": previous_checkpoint_functions_removed_count,
}
Expand Down
Loading
Loading