Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3cbd6b7
feat(optimizer): Implement targeted formatting (CF-637)
zomglings May 14, 2025
a10600c
Fixed changes to the FunctionOptimizer
zomglings May 14, 2025
bcba527
CODEFLASH_DISABLE_TELEMETRY environment variable can be set to disabl…
zomglings May 14, 2025
36da640
Added bubble sort implementation with bad formatting in non-optimized…
zomglings May 14, 2025
85fd3c0
Added a file containing a bubble sort method in a class
zomglings May 14, 2025
ce33708
Merge branch 'main' into targeted-formatting
zomglings May 14, 2025
2c40018
Added "scratch/" directory to .gitignore
zomglings May 14, 2025
82b9d41
Cleaned up the import sorting code in FunctionOptimizer
zomglings May 15, 2025
ce87832
Added test for sort_imports_in_place
zomglings May 15, 2025
0373bfa
Reverted changes to optimizer and formatter on targeted-formatting br…
zomglings May 16, 2025
81dbb33
Merge branch 'main' into targeted-formatting-cst-based
zomglings May 16, 2025
9345d80
Started work on targeted formatting using the CST
zomglings May 16, 2025
bfc8423
TODO
zomglings May 16, 2025
91fe6a7
Updated implementation of FunctionOptimizer.reformat_code_and_helpers
zomglings May 16, 2025
b89622c
Fixed a few bugs in reformat_code_and_helpers
zomglings May 16, 2025
5a9265c
More codeposition bugs
zomglings May 16, 2025
b903d1a
Issue with splicing
zomglings May 16, 2025
16ca27e
Fixing more bugs, testing live...
zomglings May 16, 2025
59e3667
Got it functional
zomglings May 16, 2025
c3b8063
Do not recalculate code_context when reformatting
zomglings May 16, 2025
6dd72cf
Correct calculation of all preexisting "function" symbols for formatt…
zomglings May 17, 2025
eae756a
Clarified docstring for get_modification_code_ranges
zomglings May 17, 2025
05817f9
removed xylophone
zomglings May 17, 2025
9efdc21
Added test for get_modification_code_ranges.
zomglings May 19, 2025
b6baf05
Merge branch 'main' into targeted-formatting-cst-based
zomglings May 21, 2025
9face63
"ruff check --fix"
zomglings May 21, 2025
0615f55
Fixed some more ruff check issues
zomglings May 21, 2025
af4df4a
"ruff format"
zomglings May 21, 2025
cf4a665
more fixes for "ruff check"...
zomglings May 21, 2025
d43862e
ruff format...
zomglings May 21, 2025
6ec5ef0
That should be hte last of the ruff stuff
zomglings May 21, 2025
a468ba7
Added a test for FunctionOptimizer.reformat_code_and_helpers
zomglings May 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,5 @@ fabric.properties

# Mac
.DS_Store

scratch/
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import sys


def lol():
print( "lol" )









class BubbleSorter:
def __init__(self, x=0):
self.x = x

def lol(self):
print( "lol" )








def sorter(self, arr):
print("codeflash stdout : BubbleSorter.sorter() called")
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
print("stderr test", file=sys.stderr)
return arr
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
def lol():
print( "lol" )







def sorter(arr):
print("codeflash stdout: Sorting list")
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
print(f"result: {arr}")
return arr
12 changes: 11 additions & 1 deletion codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def normalize_code(code: str) -> str:


class OptimFunctionCollector(cst.CSTVisitor):
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider, cst.metadata.PositionProvider)

def __init__(
self,
Expand All @@ -52,8 +52,11 @@ def __init__(
self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
self.current_class = None
self.modified_init_functions: dict[str, cst.FunctionDef] = {}
self.modification_code_range_lines: list[tuple[int, int]] = []

def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
modification = True

if (self.current_class, node.name.value) in self.function_names:
self.modified_functions[(self.current_class, node.name.value)] = node
elif self.current_class and node.name.value == "__init__":
Expand All @@ -64,6 +67,13 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
and self.current_class is None
):
self.new_functions.append(node)
else:
modification = False

if modification:
pos = self.get_metadata(cst.metadata.PositionProvider, node)
self.modification_code_range_lines.append((pos.start.line, pos.end.line))

return False

def visit_ClassDef(self, node: cst.ClassDef) -> bool:
Expand Down
36 changes: 36 additions & 0 deletions codeflash/code_utils/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@
from typing import TYPE_CHECKING

import isort
import libcst as cst

from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.code_replacer import OptimFunctionCollector

if TYPE_CHECKING:
from pathlib import Path

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent, FunctionSource


def format_code(formatter_cmds: list[str], path: Path) -> str:
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
Expand Down Expand Up @@ -55,3 +60,34 @@ def sort_imports(code: str) -> str:
return code # Fall back to original code if isort fails

return sorted_code


def get_modification_code_ranges(
modified_code: str,
fto: FunctionToOptimize,
preexisting_functions: set[tuple[str, tuple[FunctionParent, ...]]],
helper_functions: list[FunctionSource],
) -> list[tuple[int, int]]:
"""Return the starting and ending line numbers of modified and new functions in a file with edits."""
modified_functions = set()
modified_functions.add(fto.qualified_name)
for helper_function in helper_functions:
if helper_function.jedi_definition.type != "class":
modified_functions.add(helper_function.qualified_name)

parsed_function_names = set()
for original_function_name in modified_functions:
if original_function_name.count(".") == 0:
class_name, function_name = None, original_function_name
elif original_function_name.count(".") == 1:
class_name, function_name = original_function_name.split(".")
else:
msg = f"Unable to find {original_function_name}. Returning unchanged source code."
logger.error(msg)
continue
parsed_function_names.add((class_name, function_name))

module = cst.metadata.MetadataWrapper(cst.parse_module(modified_code))
visitor = OptimFunctionCollector(preexisting_functions, parsed_function_names)
module.visit(visitor)
return visitor.modification_code_range_lines
19 changes: 12 additions & 7 deletions codeflash/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
solved problem, please reach out to us at [email protected]. We're hiring!
"""

import os
from pathlib import Path

from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
Expand All @@ -22,25 +23,29 @@ def main() -> None:
CODEFLASH_LOGO, panel_args={"title": "https://codeflash.ai", "expand": False}, text_args={"style": "bold gold3"}
)
args = parse_args()

disable_telemetry_env = os.environ.get("CODEFLASH_DISABLE_TELEMETRY", "").lower() in {"true", "t", "1", "yes", "y"}

if args.command:
if args.config_file and Path.exists(args.config_file):
disable_telemetry = disable_telemetry_env
if (not disable_telemetry) and args.config_file and Path.exists(args.config_file):
pyproject_config, _ = parse_config_file(args.config_file)
disable_telemetry = pyproject_config.get("disable_telemetry", False)
else:
disable_telemetry = False
init_sentry(not disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(not disable_telemetry)
args.func()
elif args.verify_setup:
args = process_pyproject_config(args)
init_sentry(not args.disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(not args.disable_telemetry)
disable_telemetry = args.disable_telemetry or disable_telemetry_env
init_sentry(not disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(not disable_telemetry)
ask_run_end_to_end_test(args)
else:
args = process_pyproject_config(args)
args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args)
init_sentry(not args.disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(not args.disable_telemetry)
disable_telemetry = args.disable_telemetry or disable_telemetry_env
init_sentry(not disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(not disable_telemetry)
optimizer.run_with_args(args)


Expand Down
79 changes: 65 additions & 14 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from codeflash.benchmarking.utils import process_benchmark_data
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_extractor import find_preexisting_objects
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.code_utils.code_utils import (
cleanup_paths,
Expand All @@ -35,7 +36,7 @@
N_TESTS_TO_GENERATE,
TOTAL_LOOPING_TIME,
)
from codeflash.code_utils.formatter import format_code, sort_imports
from codeflash.code_utils.formatter import format_code, get_modification_code_ranges, sort_imports
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.code_utils.line_profile_utils import add_decorator_imports
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
Expand Down Expand Up @@ -77,10 +78,14 @@

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import Result
from codeflash.models.models import BenchmarkKey, CoverageData, FunctionSource, OptimizedCandidate
from codeflash.models.models import BenchmarkKey, CoverageData, FunctionParent, FunctionSource, OptimizedCandidate
from codeflash.verification.verification_utils import TestConfig


class FunctionOptimizerError(Exception):
pass


class FunctionOptimizer:
def __init__(
self,
Expand Down Expand Up @@ -296,12 +301,24 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911

self.log_successful_optimization(explanation, generated_tests, exp_type)

preexisting_functions_by_filepath: dict[Path, list[str]] = {}
filepaths_to_inspect = [
self.function_to_optimize.file_path,
*list({helper.file_path for helper in code_context.helper_functions}),
]
for filepath in filepaths_to_inspect:
source_code = filepath.read_text(encoding="utf8")
preexisting_functions_by_filepath[filepath] = find_preexisting_objects(source_code)

self.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=best_optimization.candidate.source_code
)

new_code, new_helper_code = self.reformat_code_and_helpers(
code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code
preexisting_functions_by_filepath,
code_context.helper_functions,
explanation.file_path,
self.function_to_optimize_source_code,
)

existing_tests = existing_tests_source_for(
Expand Down Expand Up @@ -587,25 +604,59 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path,
f.write(helper_code)

def reformat_code_and_helpers(
self, helper_functions: list[FunctionSource], path: Path, original_code: str
self,
preexisting_functions_by_filepath: dict[Path, set[tuple[str, tuple[FunctionParent, ...]]]],
helper_functions: list[FunctionSource],
fto_path: Path,
original_code: str,
) -> tuple[str, dict[Path, str]]:
should_sort_imports = not self.args.disable_imports_sorting
if should_sort_imports and isort.code(original_code) != original_code:
should_sort_imports = False

new_code = format_code(self.args.formatter_cmds, path)
if should_sort_imports:
new_code = sort_imports(new_code)

paths = [fto_path, *list({hf.file_path for hf in helper_functions})]
new_target_code = None
new_helper_code: dict[Path, str] = {}
helper_functions_paths = {hf.file_path for hf in helper_functions}
for module_abspath in helper_functions_paths:
formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath)
for i, path in enumerate(paths):
unformatted_code = path.read_text(encoding="utf8")
code_ranges_unformatted = get_modification_code_ranges(
unformatted_code, self.function_to_optimize, preexisting_functions_by_filepath[path], helper_functions
)
formatted_code = format_code(self.args.formatter_cmds, path)
# Note: We do not need to refresh the code_context because we only use it to refer to names of original
# functions (even before optimization was applied) and filepaths, none of which is changing.
code_ranges_formatted = get_modification_code_ranges(
formatted_code, self.function_to_optimize, preexisting_functions_by_filepath[path], helper_functions
)

if len(code_ranges_formatted) != len(code_ranges_unformatted):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: Can you help me understand the case on Why these have to be equal always?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a defensive measure. The function to optimize and helpers are non-overlapping blocks of code, and should not become overlapping due to formatting. If this has happened, something has gone wrong in the formatting step or in the extraction step (which identifies function to optimize and helpers).

raise FunctionOptimizerError("Formatting had unexpected effects on code ranges")

# It is important to sort in descending order so that the index arithmetic remains simple as we modify new_code
code_ranges_unformatted.sort(key=lambda r: r[0], reverse=True)
code_ranges_formatted.sort(key=lambda r: r[0], reverse=True)
formatted_code_lines = formatted_code.split("\n")
new_code_lines = unformatted_code.split("\n")
for range_0, range_1 in zip(code_ranges_unformatted, code_ranges_formatted):
range_0_0, range_0_1 = range_0
range_1_0, range_1_1 = range_1
new_code_lines = (
new_code_lines[:range_0_0]
+ formatted_code_lines[range_1_0 : range_1_1 + 1]
+ new_code_lines[range_0_1 + 1 :]
)
new_code = "\n".join(new_code_lines)
path.write_text(new_code, encoding="utf8")

if should_sort_imports:
formatted_helper_code = sort_imports(formatted_helper_code)
new_helper_code[module_abspath] = formatted_helper_code
new_code = sort_imports(new_code)

if i == 0:
new_target_code = new_code
else:
new_helper_code[path] = new_code

return new_code, new_helper_code
return new_target_code, new_helper_code

def replace_function_and_helpers_with_optimized_code(
self, code_context: CodeOptimizationContext, optimized_code: str
Expand Down
17 changes: 16 additions & 1 deletion tests/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
import tempfile
from pathlib import Path

from jedi.api.classes import Name
import pytest

from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.formatter import format_code, sort_imports
from codeflash.code_utils.formatter import format_code, get_modification_code_ranges, sort_imports
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionSource


def test_remove_duplicate_imports():
Expand Down Expand Up @@ -209,3 +212,15 @@ def foo():
tmp_path = tmp.name
with pytest.raises(FileNotFoundError):
format_code(formatter_cmds=["exit 1"], path=Path(tmp_path))

def test_get_modification_code_ranges_self_contained_fto():
modified_code = """
def hello(name):
print(f"Hello, {{name}}")
"""

fto = FunctionToOptimize(function_name="hello", file_path=Path("hello.py"), parents=[])
code_ranges = get_modification_code_ranges(modified_code, fto, set(), [])

assert len(code_ranges) == 1
assert code_ranges[0] == (2, 3)
Loading
Loading