Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions codeflash/code_utils/coverage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@

def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]:
"""Extract the single dependent function from the code context excluding the main function."""
ast_tree = ast.parse(code_context.testgen_context_code)

dependent_functions = {
node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
}
dependent_functions = set()
for code_string in code_context.testgen_context.code_strings:
ast_tree = ast.parse(code_string.code)
dependent_functions.update(
{node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))}
)

if main_function in dependent_functions:
dependent_functions.discard(main_function)
Expand Down
18 changes: 9 additions & 9 deletions codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,32 +114,32 @@ def get_code_optimization_context(
read_only_context_code = ""

# Extract code context for testgen
testgen_code_markdown = extract_code_string_context_from_files(
testgen_context = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=False,
code_context_type=CodeContextType.TESTGEN,
)
testgen_context_code = testgen_code_markdown.code
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
if testgen_context_code_tokens > testgen_token_limit:
testgen_code_markdown = extract_code_string_context_from_files(
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
testgen_context = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=True,
code_context_type=CodeContextType.TESTGEN,
)
testgen_context_code = testgen_code_markdown.code
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
if testgen_context_code_tokens > testgen_token_limit:
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
code_hash_context = hashing_code_context.markdown
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()

return CodeOptimizationContext(
testgen_context_code=testgen_context_code,
testgen_context=testgen_context,
read_writable_code=final_read_writable_code,
read_only_context_code=read_only_context_code,
hashing_code_context=code_hash_context,
Expand Down
4 changes: 2 additions & 2 deletions codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class CodeString(BaseModel):


def get_code_block_splitter(file_path: Path) -> str:
return f"# file: {file_path}"
return f"# file: {file_path.as_posix()}"


markdown_pattern = re.compile(r"```python:([^\n]+)\n(.*?)\n```", re.DOTALL)
Expand Down Expand Up @@ -254,7 +254,7 @@ def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown:


class CodeOptimizationContext(BaseModel):
testgen_context_code: str = ""
testgen_context: CodeStringsMarkdown
read_writable_code: CodeStringsMarkdown
read_only_context_code: str = ""
hashing_code_context: str = ""
Expand Down
9 changes: 4 additions & 5 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def generate_and_instrument_tests(
revert_to_print=bool(get_pr_number()),
):
generated_results = self.generate_tests_and_optimizations(
testgen_context_code=code_context.testgen_context_code,
testgen_context=code_context.testgen_context,
read_writable_code=code_context.read_writable_code,
read_only_context_code=code_context.read_only_context_code,
helper_functions=code_context.helper_functions,
Expand Down Expand Up @@ -345,7 +345,6 @@ def generate_and_instrument_tests(
logger.info(f"Generated test {i + 1}/{count_tests}:")
code_print(generated_test.generated_original_test_source, file_name=f"test_{i + 1}.py")
if concolic_test_str:
# no concolic tests in lsp mode
logger.info(f"Generated test {count_tests}/{count_tests}:")
code_print(concolic_test_str)

Expand Down Expand Up @@ -972,7 +971,7 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:

return Success(
CodeOptimizationContext(
testgen_context_code=new_code_ctx.testgen_context_code,
testgen_context=new_code_ctx.testgen_context,
read_writable_code=new_code_ctx.read_writable_code,
read_only_context_code=new_code_ctx.read_only_context_code,
hashing_code_context=new_code_ctx.hashing_code_context,
Expand Down Expand Up @@ -1079,7 +1078,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio

def generate_tests_and_optimizations(
self,
testgen_context_code: str,
testgen_context: CodeStringsMarkdown,
read_writable_code: CodeStringsMarkdown,
read_only_context_code: str,
helper_functions: list[FunctionSource],
Expand All @@ -1093,7 +1092,7 @@ def generate_tests_and_optimizations(
# Submit the test generation task as future
future_tests = self.submit_test_generation_tasks(
self.executor,
testgen_context_code,
testgen_context.markdown,
[definition.fully_qualified_name for definition in helper_functions],
generated_test_paths,
generated_perf_test_paths,
Expand Down
5 changes: 3 additions & 2 deletions tests/test_code_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,8 @@ def main_method(self):


def test_code_replacement10() -> None:
get_code_output = """from __future__ import annotations
get_code_output = """# file: test_code_replacement.py
from __future__ import annotations

class HelperClass:
def __init__(self, name):
Expand Down Expand Up @@ -828,7 +829,7 @@ def main_method(self):
)
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
code_context = func_optimizer.get_code_optimization_context().unwrap()
assert code_context.testgen_context_code.rstrip() == get_code_output.rstrip()
assert code_context.testgen_context.flat.rstrip() == get_code_output.rstrip()


def test_code_replacement11() -> None:
Expand Down
34 changes: 21 additions & 13 deletions tests/test_code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from codeflash.code_utils.concolic_utils import clean_concolic_tests
from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
from codeflash.models.models import CodeStringsMarkdown


@pytest.fixture
Expand Down Expand Up @@ -382,69 +383,76 @@ def mock_code_context():
def test_extract_dependent_function_sync_and_async(mock_code_context):
"""Test extract_dependent_function with both sync and async functions."""
# Test sync function extraction
mock_code_context.testgen_context_code = """
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
def main_function():
pass

def helper_function():
pass
"""
```
""")
assert extract_dependent_function("main_function", mock_code_context) == "helper_function"

# Test async function extraction
mock_code_context.testgen_context_code = """
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
def main_function():
pass

async def async_helper_function():
pass
"""
```
""")

assert extract_dependent_function("main_function", mock_code_context) == "async_helper_function"


def test_extract_dependent_function_edge_cases(mock_code_context):
"""Test extract_dependent_function edge cases."""
# No dependent functions
mock_code_context.testgen_context_code = """
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
def main_function():
pass
"""
```
""")
assert extract_dependent_function("main_function", mock_code_context) is False

# Multiple dependent functions
mock_code_context.testgen_context_code = """
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
def main_function():
pass

def helper1():
pass

async def helper2():
pass
"""
```
""")
assert extract_dependent_function("main_function", mock_code_context) is False


def test_extract_dependent_function_mixed_scenarios(mock_code_context):
"""Test extract_dependent_function with mixed sync/async scenarios."""
# Async main with sync helper
mock_code_context.testgen_context_code = """
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
async def async_main():
pass

def sync_helper():
pass
"""
```
""")
assert extract_dependent_function("async_main", mock_code_context) == "sync_helper"

# Only async functions
mock_code_context.testgen_context_code = """
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
async def async_main():
pass

async def async_helper():
pass
"""
```
""")

assert extract_dependent_function("async_main", mock_code_context) == "async_helper"


Expand Down
10 changes: 6 additions & 4 deletions tests/test_function_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,9 @@ def test_class_method_dependencies() -> None:
)
assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
assert (
code_context.testgen_context_code
== """from collections import defaultdict
code_context.testgen_context.flat
== """# file: test_function_dependencies.py
from collections import defaultdict

class Graph:
def __init__(self, vertices):
Expand Down Expand Up @@ -220,8 +221,9 @@ def test_recursive_function_context() -> None:
assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3"
assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive"
assert (
code_context.testgen_context_code
== """class C:
code_context.testgen_context.flat
== """# file: test_function_dependencies.py
class C:
def calculate_something_3(self, num):
return num + 1

Expand Down
21 changes: 12 additions & 9 deletions tests/test_get_helper_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import is_successful
from codeflash.models.models import FunctionParent
from codeflash.models.models import FunctionParent, get_code_block_splitter
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.verification_utils import TestConfig
Expand Down Expand Up @@ -242,8 +242,9 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
code_context = ctx_result.unwrap()
assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
assert (
code_context.testgen_context_code
== f'''_P = ParamSpec("_P")
code_context.testgen_context.flat
== f'''# file: {file_path.relative_to(project_root_path)}
_P = ParamSpec("_P")
_KEY_T = TypeVar("_KEY_T")
_STORE_T = TypeVar("_STORE_T")
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
Expand Down Expand Up @@ -395,10 +396,11 @@ def test_bubble_sort_deps() -> None:
function_to_optimize = FunctionToOptimize(
function_name="sorter_deps", file_path=file_path, parents=[], starting_line=None, ending_line=None
)
project_root = file_path.parent.parent.resolve()
test_config = TestConfig(
tests_root=str(file_path.parent / "tests"),
tests_project_rootdir=file_path.parent.resolve(),
project_root_path=file_path.parent.parent.resolve(),
project_root_path=project_root,
test_framework="pytest",
pytest_cmd="pytest",
)
Expand All @@ -410,19 +412,20 @@ def test_bubble_sort_deps() -> None:
pytest.fail()
code_context = ctx_result.unwrap()
assert (
code_context.testgen_context_code
== """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap

code_context.testgen_context.flat
== f"""{get_code_block_splitter(Path("code_to_optimize/bubble_sort_dep1_helper.py"))}
def dep1_comparer(arr, j: int) -> bool:
return arr[j] > arr[j + 1]

{get_code_block_splitter(Path("code_to_optimize/bubble_sort_dep2_swap.py"))}
def dep2_swap(arr, j):
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp


{get_code_block_splitter(Path("code_to_optimize/bubble_sort_deps.py"))}
from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap

def sorter_deps(arr):
for i in range(len(arr)):
Expand Down
Loading