Merged
Changes from 5 commits

Commits (37)
1f7124a
WIP
misrasaurabh1 Apr 28, 2025
41378a0
WIP
misrasaurabh1 Apr 28, 2025
e9746c9
batch code hash check
dasarchan Jun 2, 2025
5760316
implemented hash check into filter_functions
dasarchan Jun 3, 2025
905b1a0
Merge branch 'main' into dont-optimize-repeatedly-gh-actions
Jun 3, 2025
2367160
removed prints, added cfapi.py func
dasarchan Jun 5, 2025
f2733b3
Merge branch 'dont-optimize-repeatedly-gh-actions' of https://github.…
dasarchan Jun 5, 2025
c1fb089
removed unused import
dasarchan Jun 5, 2025
3443404
Merge branch 'main' into dont-optimize-repeatedly-gh-actions
misrasaurabh1 Jun 6, 2025
eb3d305
fix no git error
misrasaurabh1 Jun 6, 2025
c862b4d
add low prob of repeating optimization
dasarchan Jun 6, 2025
96ee580
changes to cli for code context hash
dasarchan Jun 7, 2025
87fe086
update the cli
misrasaurabh1 Jun 7, 2025
4cb823e
added separate write route, changed return format for api route
dasarchan Jun 7, 2025
1cc39e3
merge
dasarchan Jun 7, 2025
dd8dceb
removed empty test file
dasarchan Jun 7, 2025
5989b26
updates
dasarchan Jun 7, 2025
5c0a028
Add a first version of hashing code context
misrasaurabh1 Jun 8, 2025
2686682
Might work?
misrasaurabh1 Jun 8, 2025
4f39794
get it working
misrasaurabh1 Jun 8, 2025
50f4c33
10% chance of optimizing again
misrasaurabh1 Jun 8, 2025
81f96ed
Merge branch 'main' into dont-optimize-repeatedly-gh-actions
misrasaurabh1 Jun 8, 2025
c856f1e
fix a bug
misrasaurabh1 Jun 8, 2025
b48ed5c
ruff fix
misrasaurabh1 Jun 8, 2025
9e14cfe
fix bugs with docstring removal
misrasaurabh1 Jun 8, 2025
5d4870f
fix a type
misrasaurabh1 Jun 8, 2025
2c1314d
fix more tests
misrasaurabh1 Jun 8, 2025
32a8001
fix types for python 3.9
misrasaurabh1 Jun 8, 2025
e2f1ba0
clearer message
misrasaurabh1 Jun 8, 2025
f6b3275
fix mypy types
misrasaurabh1 Jun 8, 2025
6ed9387
add more tests
misrasaurabh1 Jun 8, 2025
be1ef9b
fix for test
misrasaurabh1 Jun 8, 2025
9137921
double the context length
misrasaurabh1 Jun 8, 2025
797cba3
ruff revert
misrasaurabh1 Jun 8, 2025
d0f84f6
improve some github actions logging
misrasaurabh1 Jun 8, 2025
2d62171
some refactor
misrasaurabh1 Jun 9, 2025
226acd7
remove unncessary line
misrasaurabh1 Jun 9, 2025
19 changes: 19 additions & 0 deletions codeflash/api/cfapi.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import hashlib
import json
import os
import sys
@@ -14,6 +15,7 @@
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
from codeflash.code_utils.git_utils import get_repo_owner_and_name
from codeflash.models.models import CodeOptimizationContext
from codeflash.version import __version__

if TYPE_CHECKING:
@@ -191,3 +193,20 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
return {}

return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()}


def is_function_being_optimized_again(code_context: CodeOptimizationContext) -> bool:
"""Check if the function being optimized is being optimized again."""
pr_number = get_pr_number()
if pr_number is None:
# Only want to do this check during GH Actions
return False
owner, repo = get_repo_owner_and_name()
# TODO: Add file paths
rw_context_hash = hashlib.sha256(str(code_context.read_writable_code).encode()).hexdigest()

payload = {"owner": owner, "repo": repo, "pullNumber": pr_number, "code_hash": rw_context_hash}
response = make_cfapi_request(endpoint="/is-function-being-optimized-again", method="POST", payload=payload)
if not response.ok or response.text != "true":
logger.error(f"Error: {response.text}")
        return False
    return True
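
The duplicate check above keys on a SHA-256 digest of the read-writable code context, so an unchanged function produces the same code_hash on every CI run. Below is a minimal, illustrative sketch of that property; the code strings and the context_hash helper are stand-ins for this example, not part of the PR.

import hashlib


def context_hash(read_writable_code: str) -> str:
    # Mirrors the hashing in is_function_being_optimized_again: a digest of the
    # read-writable code only (adding file paths is still a TODO in the diff above).
    return hashlib.sha256(read_writable_code.encode()).hexdigest()


v1 = "def bubble_sort(arr):\n    return sorted(arr)\n"
v2 = "def bubble_sort(arr):\n    return sorted(arr)  # changed\n"

assert context_hash(v1) == context_hash(v1)  # stable digest -> the backend can recognize a repeat
assert context_hash(v1) != context_hash(v2)  # any code change -> treated as new work to optimize
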
183 changes: 160 additions & 23 deletions codeflash/discovery/functions_to_optimize.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import ast
import hashlib
import os
import random
import warnings
Expand All @@ -14,14 +15,15 @@
import libcst as cst
from pydantic.dataclasses import dataclass

from codeflash.api.cfapi import get_blocklisted_functions
from codeflash.api.cfapi import get_blocklisted_functions, make_cfapi_request
from codeflash.cli_cmds.console import DEBUG_MODE, console, logger
from codeflash.code_utils.code_utils import (
is_class_defined_in_file,
module_name_from_file_path,
path_belongs_to_site_packages,
)
from codeflash.code_utils.git_utils import get_git_diff
from codeflash.code_utils.env_utils import get_pr_number
from codeflash.code_utils.git_utils import get_git_diff, get_repo_owner_and_name
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.models.models import FunctionParent
@@ -144,6 +146,45 @@ def qualified_name(self) -> str:
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"

def get_code_context_hash(self) -> str:
"""Generate a SHA-256 hash representing the code context of this function.

This hash includes the function's code content, file path, and qualified name
to uniquely identify the function for optimization tracking.
"""
try:
with open(self.file_path, 'r', encoding='utf-8') as f:
file_content = f.read()

# Extract the function's code content
lines = file_content.splitlines()
if self.starting_line is not None and self.ending_line is not None:
# Use line numbers if available (1-indexed to 0-indexed)
function_content = '\n'.join(lines[self.starting_line - 1:self.ending_line])
else:
# Fallback: use the entire file content if line numbers aren't available
function_content = file_content

# Create a context string that includes:
# - File path (relative to make it portable)
# - Qualified function name
# - Function code content
context_parts = [
str(self.file_path.name), # Just filename for portability
self.qualified_name,
function_content.strip()
]

context_string = '\n---\n'.join(context_parts)

# Generate SHA-256 hash
return hashlib.sha256(context_string.encode('utf-8')).hexdigest()

except (OSError, IOError) as e:
logger.warning(f"Could not read file {self.file_path} for hashing: {e}")
# Fallback hash using available metadata
fallback_string = f"{self.file_path.name}:{self.qualified_name}"
return hashlib.sha256(fallback_string.encode('utf-8')).hexdigest()

def get_functions_to_optimize(
optimize_all: str | None,
@@ -187,7 +228,7 @@ def get_functions_to_optimize(
found_function = None
for fn in functions.get(file, []):
if only_function_name == fn.function_name and (
class_name is None or class_name == fn.top_level_parent_name
class_name is None or class_name == fn.top_level_parent_name
):
found_function = fn
if found_function is None:
@@ -266,7 +307,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt


def get_all_replay_test_functions(
replay_test: Path, test_cfg: TestConfig, project_root_path: Path
replay_test: Path, test_cfg: TestConfig, project_root_path: Path
) -> dict[Path, list[FunctionToOptimize]]:
function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
# Get the absolute file paths for each function, excluding class name if present
@@ -281,7 +322,7 @@ def get_all_replay_test_functions(
class_name = (
module_path_parts[-1]
if module_path_parts
and is_class_defined_in_file(
and is_class_defined_in_file(
module_path_parts[-1], Path(project_root_path, *module_path_parts[:-1]).with_suffix(".py")
)
else None
@@ -333,7 +374,8 @@ def ignored_submodule_paths(module_root: str) -> list[str]:

class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
def __init__(
self, file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None
self, file_name: Path, function_or_method_name: str, class_name: str | None = None,
line_no: int | None = None
) -> None:
self.file_name = file_name
self.class_name = class_name
@@ -364,13 +406,13 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name:
self.is_top_level = True
if any(
isinstance(decorator, ast.Name) and decorator.id == "classmethod"
for decorator in body_node.decorator_list
isinstance(decorator, ast.Name) and decorator.id == "classmethod"
for decorator in body_node.decorator_list
):
self.is_classmethod = True
elif any(
isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
for decorator in body_node.decorator_list
isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
for decorator in body_node.decorator_list
):
self.is_staticmethod = True
return
@@ -379,13 +421,13 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
# This way, if we don't have the class name, we can still find the static method
for body_node in node.body:
if (
isinstance(body_node, ast.FunctionDef)
and body_node.name == self.function_name
and body_node.lineno in {self.line_no, self.line_no + 1}
and any(
isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
for decorator in body_node.decorator_list
)
isinstance(body_node, ast.FunctionDef)
and body_node.name == self.function_name
and body_node.lineno in {self.line_no, self.line_no + 1}
and any(
isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
for decorator in body_node.decorator_list
)
):
self.is_staticmethod = True
self.is_top_level = True
@@ -417,6 +459,82 @@ def inspect_top_level_functions_or_methods(
)


def check_optimization_status(
functions_by_file: dict[Path, list[FunctionToOptimize]],
Review comment (Contributor): Q: What will be the scenario where the function is already an optimized code by CF?

Review comment (Contributor Author): Right now it would try to reoptimize it, actually - @misrasaurabh1 what's the desired behavior here

owner: str,
repo: str,
pr_number: int
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
"""Check which functions have already been optimized and filter them out.

This function calls the optimization API to:
1. Check which functions are already optimized
2. Log new function hashes to the database
3. Return only functions that need optimization

Args:
functions_by_file: Dictionary mapping file paths to lists of functions
owner: Repository owner
repo: Repository name
pr_number: Pull request number

Returns:
Tuple of (filtered_functions_dict, remaining_count)
"""
logger.info("entering function")
# Build the code_contexts dictionary for the API call
code_contexts = {}
path_to_function_map = {}

for file_path, functions in functions_by_file.items():
for func in functions:
func_hash = func.get_code_context_hash()
# Use a unique path identifier that includes function info
path_key = f"{file_path}:{func.qualified_name}"
code_contexts[path_key] = func_hash
path_to_function_map[path_key] = (file_path, func)

if not code_contexts:
return {}, 0

try:
# Call the optimization check API
logger.info("Checking status")
response = make_cfapi_request(
"/is-already-optimized",
"POST",
{
"owner": owner,
"repo": repo,
"pr_number": pr_number,
"code_contexts": code_contexts
}
)
response.raise_for_status()
result = response.json()
already_optimized_paths = set(result.get("already_optimized_paths", []))

logger.info(f"Found {len(already_optimized_paths)} already optimized functions")

# Filter out already optimized functions
filtered_functions = defaultdict(list)
remaining_count = 0

for path_key, (file_path, func) in path_to_function_map.items():
if path_key not in already_optimized_paths:
filtered_functions[file_path].append(func)
remaining_count += 1

return dict(filtered_functions), remaining_count

except Exception as e:
logger.warning(f"Failed to check optimization status: {e}")
logger.info("Proceeding with all functions (optimization check failed)")
# Return all functions if API call fails
total_count = sum(len(funcs) for funcs in functions_by_file.values())
return functions_by_file, total_count


def filter_functions(
modified_functions: dict[Path, list[FunctionToOptimize]],
tests_root: Path,
@@ -426,6 +544,7 @@
previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
disable_logs: bool = False, # noqa: FBT001, FBT002
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
logger.info("filtering functions boogaloo")
blocklist_funcs = get_blocklisted_functions()
logger.debug(f"Blocklisted functions: {blocklist_funcs}")
# Remove any function that we don't want to optimize
@@ -445,6 +564,7 @@
previous_checkpoint_functions_removed_count: int = 0
tests_root_str = str(tests_root)
module_root_str = str(module_root)

# We desperately need Python 3.10+ only support to make this code readable with structural pattern matching
for file_path_path, functions in modified_functions.items():
_functions = functions
@@ -453,12 +573,12 @@
test_functions_removed_count += len(_functions)
continue
if file_path in ignore_paths or any(
file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths
file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths
):
ignore_paths_removed_count += 1
continue
if file_path in submodule_paths or any(
file_path.startswith(str(submodule_path) + os.sep) for submodule_path in submodule_paths
file_path.startswith(str(submodule_path) + os.sep) for submodule_path in submodule_paths
):
submodule_ignored_paths_count += 1
continue
@@ -497,6 +617,22 @@ def filter_functions(
filtered_modified_functions[file_path] = _functions
functions_count += len(_functions)

# Convert to Path keys for optimization check
path_based_functions = {Path(k): v for k, v in filtered_modified_functions.items() if v}

# Check optimization status if repository info is provided
already_optimized_count = 0
repository = git.Repo(Path.cwd(), search_parent_directories=True)
owner, repo = get_repo_owner_and_name(repository)
pr_number = get_pr_number()
print(owner, repo, pr_number)
if owner and repo and pr_number is not None:
path_based_functions, functions_count = check_optimization_status(
path_based_functions, owner, repo, pr_number
)
initial_count = sum(len(funcs) for funcs in filtered_modified_functions.values())
already_optimized_count = initial_count - functions_count

if not disable_logs:
log_info = {
f"{test_functions_removed_count} test function{'s' if test_functions_removed_count != 1 else ''}": test_functions_removed_count,
@@ -505,6 +641,7 @@
f"{non_modules_removed_count} function{'s' if non_modules_removed_count != 1 else ''} outside module-root": non_modules_removed_count,
f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count,
f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count,
f"{already_optimized_count} already optimized function{'s' if already_optimized_count != 1 else ''}": already_optimized_count,
f"{blocklist_funcs_removed_count} function{'s' if blocklist_funcs_removed_count != 1 else ''} as previously optimized": blocklist_funcs_removed_count,
f"{previous_checkpoint_functions_removed_count} function{'s' if previous_checkpoint_functions_removed_count != 1 else ''} skipped from checkpoint": previous_checkpoint_functions_removed_count,
}
@@ -513,7 +650,7 @@
logger.info(f"Ignoring: {log_string}")
console.rule()

return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count
return path_based_functions, functions_count


def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list[Path], module_root: Path) -> bool:
@@ -533,8 +670,8 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
if submodule_paths is None:
submodule_paths = ignored_submodule_paths(module_root)
return not (
file_path in submodule_paths
or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths)
file_path in submodule_paths
or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths)
)


@@ -543,4 +680,4 @@ def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef)


def function_is_a_property(function_node: FunctionDef | AsyncFunctionDef) -> bool:
return any(isinstance(node, ast.Name) and node.id == "property" for node in function_node.decorator_list)
return any(isinstance(node, ast.Name) and node.id == "property" for node in function_node.decorator_list)
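
For readers skimming the diff, the core of check_optimization_status is: hash each candidate, key it as file_path:qualified_name, ask /is-already-optimized which keys the backend has already seen, and keep only the rest. The self-contained sketch below mocks the backend reply; the paths, function names, and digests are invented for illustration, not taken from a real run.

from collections import defaultdict
from pathlib import Path

# Per-function hashes, keyed the same way check_optimization_status keys them.
code_contexts = {
    "src/app.py:slow_function": "ab12...",  # placeholder digests
    "src/app.py:fast_function": "cd34...",
}
path_to_function_map = {
    "src/app.py:slow_function": (Path("src/app.py"), "slow_function"),
    "src/app.py:fast_function": (Path("src/app.py"), "fast_function"),
}

# Pretend the backend reports slow_function as already optimized on this PR.
already_optimized_paths = {"src/app.py:slow_function"}

filtered_functions = defaultdict(list)
remaining_count = 0
for path_key, (file_path, func) in path_to_function_map.items():
    if path_key not in already_optimized_paths:
        filtered_functions[file_path].append(func)
        remaining_count += 1

print(dict(filtered_functions), remaining_count)  # only fast_function remains, count = 1
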
4 changes: 4 additions & 0 deletions codeflash/optimization/function_optimizer.py
@@ -18,6 +18,7 @@
from rich.tree import Tree

from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.api.cfapi import is_function_being_optimized_again
from codeflash.benchmarking.utils import process_benchmark_data
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
from codeflash.code_utils import env_utils
@@ -144,6 +145,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
if has_any_async_functions(code_context.read_writable_code):
return Failure("Codeflash does not support async functions in the code to optimize.")

if is_function_being_optimized_again(code_context=code_context):
return Failure("This code has already been optimized earlier")

code_print(code_context.read_writable_code)
generated_test_paths = [
get_test_file_path(
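
Commit messages in this PR ("add low prob of repeating optimization", "10% chance of optimizing again") indicate that a small fraction of previously optimized functions is still allowed through. The hunks shown here do not reveal where that probability is applied (client or backend), so the sketch below is only a hypothetical illustration of the idea, with invented names.

import random


def should_skip(already_optimized: bool, retry_probability: float = 0.1) -> bool:
    # Skip functions whose code hash was seen before, except for a small random
    # fraction that is allowed through for another optimization attempt.
    if not already_optimized:
        return False
    return random.random() >= retry_probability


# Roughly 10 out of 100 previously optimized functions would get another attempt.
retries = sum(not should_skip(True) for _ in range(100))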