diff --git a/codeflash/code_utils/coverage_utils.py b/codeflash/code_utils/coverage_utils.py index 8dd5e6c32..ed3d277a4 100644 --- a/codeflash/code_utils/coverage_utils.py +++ b/codeflash/code_utils/coverage_utils.py @@ -12,11 +12,12 @@ def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]: """Extract the single dependent function from the code context excluding the main function.""" - ast_tree = ast.parse(code_context.testgen_context_code) - - dependent_functions = { - node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) - } + dependent_functions = set() + for code_string in code_context.testgen_context.code_strings: + ast_tree = ast.parse(code_string.code) + dependent_functions.update( + {node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))} + ) if main_function in dependent_functions: dependent_functions.discard(main_function) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 76174877a..54fda3e16 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -114,32 +114,32 @@ def get_code_optimization_context( read_only_context_code = "" # Extract code context for testgen - testgen_code_markdown = extract_code_string_context_from_files( + testgen_context = extract_code_markdown_context_from_files( helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.TESTGEN, ) - testgen_context_code = testgen_code_markdown.code - testgen_context_code_tokens = encoded_tokens_len(testgen_context_code) - if testgen_context_code_tokens > testgen_token_limit: - testgen_code_markdown = extract_code_string_context_from_files( + testgen_markdown_code = testgen_context.markdown + testgen_code_token_length = encoded_tokens_len(testgen_markdown_code) + if testgen_code_token_length > testgen_token_limit: + testgen_context = extract_code_markdown_context_from_files( helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True, code_context_type=CodeContextType.TESTGEN, ) - testgen_context_code = testgen_code_markdown.code - testgen_context_code_tokens = encoded_tokens_len(testgen_context_code) - if testgen_context_code_tokens > testgen_token_limit: + testgen_markdown_code = testgen_context.markdown + testgen_code_token_length = encoded_tokens_len(testgen_markdown_code) + if testgen_code_token_length > testgen_token_limit: raise ValueError("Testgen code context has exceeded token limit, cannot proceed") code_hash_context = hashing_code_context.markdown code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest() return CodeOptimizationContext( - testgen_context_code=testgen_context_code, + testgen_context=testgen_context, read_writable_code=final_read_writable_code, read_only_context_code=read_only_context_code, hashing_code_context=code_hash_context, diff --git a/codeflash/models/models.py b/codeflash/models/models.py index c66775166..84179054e 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -163,7 +163,7 @@ class CodeString(BaseModel): def get_code_block_splitter(file_path: Path) -> str: - return f"# file: {file_path}" + return f"# file: {file_path.as_posix()}" markdown_pattern = re.compile(r"```python:([^\n]+)\n(.*?)\n```", re.DOTALL) @@ -254,7 +254,7 @@ def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown: class CodeOptimizationContext(BaseModel): - testgen_context_code: str = "" + testgen_context: CodeStringsMarkdown read_writable_code: CodeStringsMarkdown read_only_context_code: str = "" hashing_code_context: str = "" diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 58f0e172f..d43198e68 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -309,7 +309,7 @@ def generate_and_instrument_tests( revert_to_print=bool(get_pr_number()), ): generated_results = self.generate_tests_and_optimizations( - testgen_context_code=code_context.testgen_context_code, + testgen_context=code_context.testgen_context, read_writable_code=code_context.read_writable_code, read_only_context_code=code_context.read_only_context_code, helper_functions=code_context.helper_functions, @@ -345,7 +345,6 @@ def generate_and_instrument_tests( logger.info(f"Generated test {i + 1}/{count_tests}:") code_print(generated_test.generated_original_test_source, file_name=f"test_{i + 1}.py") if concolic_test_str: - # no concolic tests in lsp mode logger.info(f"Generated test {count_tests}/{count_tests}:") code_print(concolic_test_str) @@ -972,7 +971,7 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: return Success( CodeOptimizationContext( - testgen_context_code=new_code_ctx.testgen_context_code, + testgen_context=new_code_ctx.testgen_context, read_writable_code=new_code_ctx.read_writable_code, read_only_context_code=new_code_ctx.read_only_context_code, hashing_code_context=new_code_ctx.hashing_code_context, @@ -1079,7 +1078,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio def generate_tests_and_optimizations( self, - testgen_context_code: str, + testgen_context: CodeStringsMarkdown, read_writable_code: CodeStringsMarkdown, read_only_context_code: str, helper_functions: list[FunctionSource], @@ -1093,7 +1092,7 @@ def generate_tests_and_optimizations( # Submit the test generation task as future future_tests = self.submit_test_generation_tasks( self.executor, - testgen_context_code, + testgen_context.markdown, [definition.fully_qualified_name for definition in helper_functions], generated_test_paths, generated_perf_test_paths, diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index b04a5d64b..86e5f989d 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -798,7 +798,8 @@ def main_method(self): def test_code_replacement10() -> None: - get_code_output = """from __future__ import annotations + get_code_output = """# file: test_code_replacement.py +from __future__ import annotations class HelperClass: def __init__(self, name): @@ -828,7 +829,7 @@ def main_method(self): ) func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config) code_context = func_optimizer.get_code_optimization_context().unwrap() - assert code_context.testgen_context_code.rstrip() == get_code_output.rstrip() + assert code_context.testgen_context.flat.rstrip() == get_code_output.rstrip() def test_code_replacement11() -> None: diff --git a/tests/test_code_utils.py b/tests/test_code_utils.py index 1fa2f95fe..c7abbae7b 100644 --- a/tests/test_code_utils.py +++ b/tests/test_code_utils.py @@ -21,6 +21,7 @@ ) from codeflash.code_utils.concolic_utils import clean_concolic_tests from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files +from codeflash.models.models import CodeStringsMarkdown @pytest.fixture @@ -382,69 +383,76 @@ def mock_code_context(): def test_extract_dependent_function_sync_and_async(mock_code_context): """Test extract_dependent_function with both sync and async functions.""" # Test sync function extraction - mock_code_context.testgen_context_code = """ + mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py def main_function(): pass def helper_function(): pass -""" +``` +""") assert extract_dependent_function("main_function", mock_code_context) == "helper_function" # Test async function extraction - mock_code_context.testgen_context_code = """ + mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py def main_function(): pass async def async_helper_function(): pass -""" +``` +""") + assert extract_dependent_function("main_function", mock_code_context) == "async_helper_function" def test_extract_dependent_function_edge_cases(mock_code_context): """Test extract_dependent_function edge cases.""" # No dependent functions - mock_code_context.testgen_context_code = """ + mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py def main_function(): pass -""" +``` +""") assert extract_dependent_function("main_function", mock_code_context) is False # Multiple dependent functions - mock_code_context.testgen_context_code = """ + mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py def main_function(): pass - def helper1(): pass async def helper2(): pass -""" +``` +""") assert extract_dependent_function("main_function", mock_code_context) is False def test_extract_dependent_function_mixed_scenarios(mock_code_context): """Test extract_dependent_function with mixed sync/async scenarios.""" # Async main with sync helper - mock_code_context.testgen_context_code = """ + mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py async def async_main(): pass def sync_helper(): pass -""" +``` +""") assert extract_dependent_function("async_main", mock_code_context) == "sync_helper" # Only async functions - mock_code_context.testgen_context_code = """ + mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py async def async_main(): pass async def async_helper(): pass -""" +``` +""") + assert extract_dependent_function("async_main", mock_code_context) == "async_helper" diff --git a/tests/test_function_dependencies.py b/tests/test_function_dependencies.py index 019cc4261..4a886ba8d 100644 --- a/tests/test_function_dependencies.py +++ b/tests/test_function_dependencies.py @@ -160,8 +160,9 @@ def test_class_method_dependencies() -> None: ) assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil" assert ( - code_context.testgen_context_code - == """from collections import defaultdict + code_context.testgen_context.flat + == """# file: test_function_dependencies.py +from collections import defaultdict class Graph: def __init__(self, vertices): @@ -220,8 +221,9 @@ def test_recursive_function_context() -> None: assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3" assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive" assert ( - code_context.testgen_context_code - == """class C: + code_context.testgen_context.flat + == """# file: test_function_dependencies.py +class C: def calculate_something_3(self, num): return num + 1 diff --git a/tests/test_get_helper_code.py b/tests/test_get_helper_code.py index a6c300312..7ea5056dd 100644 --- a/tests/test_get_helper_code.py +++ b/tests/test_get_helper_code.py @@ -6,7 +6,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import is_successful -from codeflash.models.models import FunctionParent +from codeflash.models.models import FunctionParent, get_code_block_splitter from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.optimization.optimizer import Optimizer from codeflash.verification.verification_utils import TestConfig @@ -242,8 +242,9 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: code_context = ctx_result.unwrap() assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call" assert ( - code_context.testgen_context_code - == f'''_P = ParamSpec("_P") + code_context.testgen_context.flat + == f'''# file: {file_path.relative_to(project_root_path)} +_P = ParamSpec("_P") _KEY_T = TypeVar("_KEY_T") _STORE_T = TypeVar("_STORE_T") class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): @@ -395,10 +396,11 @@ def test_bubble_sort_deps() -> None: function_to_optimize = FunctionToOptimize( function_name="sorter_deps", file_path=file_path, parents=[], starting_line=None, ending_line=None ) + project_root = file_path.parent.parent.resolve() test_config = TestConfig( tests_root=str(file_path.parent / "tests"), tests_project_rootdir=file_path.parent.resolve(), - project_root_path=file_path.parent.parent.resolve(), + project_root_path=project_root, test_framework="pytest", pytest_cmd="pytest", ) @@ -410,19 +412,20 @@ def test_bubble_sort_deps() -> None: pytest.fail() code_context = ctx_result.unwrap() assert ( - code_context.testgen_context_code - == """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer -from code_to_optimize.bubble_sort_dep2_swap import dep2_swap - + code_context.testgen_context.flat + == f"""{get_code_block_splitter(Path("code_to_optimize/bubble_sort_dep1_helper.py"))} def dep1_comparer(arr, j: int) -> bool: return arr[j] > arr[j + 1] +{get_code_block_splitter(Path("code_to_optimize/bubble_sort_dep2_swap.py"))} def dep2_swap(arr, j): temp = arr[j] arr[j] = arr[j + 1] arr[j + 1] = temp - +{get_code_block_splitter(Path("code_to_optimize/bubble_sort_deps.py"))} +from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer +from code_to_optimize.bubble_sort_dep2_swap import dep2_swap def sorter_deps(arr): for i in range(len(arr)):