From 5981d75b3e2743e77f70534de49a3c3c9e1205bb Mon Sep 17 00:00:00 2001 From: ali Date: Thu, 25 Sep 2025 03:31:05 +0300 Subject: [PATCH 1/4] use markdown context for the testgen --- codeflash/code_utils/coverage_utils.py | 7 ++++--- codeflash/context/code_context_extractor.py | 18 +++++++++--------- codeflash/models/models.py | 2 +- codeflash/optimization/function_optimizer.py | 9 ++++----- tests/test_code_replacement.py | 2 +- tests/test_function_dependencies.py | 4 ++-- tests/test_get_helper_code.py | 4 ++-- 7 files changed, 23 insertions(+), 23 deletions(-) diff --git a/codeflash/code_utils/coverage_utils.py b/codeflash/code_utils/coverage_utils.py index 966910630..60ebc41fb 100644 --- a/codeflash/code_utils/coverage_utils.py +++ b/codeflash/code_utils/coverage_utils.py @@ -12,9 +12,10 @@ def extract_dependent_function(main_function: str, code_context: CodeOptimizationContext) -> str | Literal[False]: """Extract the single dependent function from the code context excluding the main function.""" - ast_tree = ast.parse(code_context.testgen_context_code) - - dependent_functions = {node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)} + dependent_functions = set() + for code_string in code_context.testgen_context.code_strings: + ast_tree = ast.parse(code_string.code) + dependent_functions.update({node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)}) if main_function in dependent_functions: dependent_functions.discard(main_function) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 09c0c564a..affd70a13 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -114,32 +114,32 @@ def get_code_optimization_context( read_only_context_code = "" # Extract code context for testgen - testgen_code_markdown = extract_code_string_context_from_files( + testgen_context = extract_code_markdown_context_from_files( helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.TESTGEN, ) - testgen_context_code = testgen_code_markdown.code - testgen_context_code_tokens = encoded_tokens_len(testgen_context_code) - if testgen_context_code_tokens > testgen_token_limit: - testgen_code_markdown = extract_code_string_context_from_files( + testgen_markdown_code = testgen_context.markdown + testgen_code_token_length = encoded_tokens_len(testgen_markdown_code) + if testgen_code_token_length > testgen_token_limit: + testgen_context = extract_code_markdown_context_from_files( helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True, code_context_type=CodeContextType.TESTGEN, ) - testgen_context_code = testgen_code_markdown.code - testgen_context_code_tokens = encoded_tokens_len(testgen_context_code) - if testgen_context_code_tokens > testgen_token_limit: + testgen_markdown_code = testgen_context.markdown + testgen_code_token_length = encoded_tokens_len(testgen_markdown_code) + if testgen_code_token_length > testgen_token_limit: raise ValueError("Testgen code context has exceeded token limit, cannot proceed") code_hash_context = hashing_code_context.markdown code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest() return CodeOptimizationContext( - testgen_context_code=testgen_context_code, + testgen_context=testgen_context, read_writable_code=final_read_writable_code, read_only_context_code=read_only_context_code, hashing_code_context=code_hash_context, diff --git a/codeflash/models/models.py b/codeflash/models/models.py index c1a563672..b68b7ce24 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -253,7 +253,7 @@ def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown: class CodeOptimizationContext(BaseModel): - testgen_context_code: str = "" + testgen_context: CodeStringsMarkdown read_writable_code: CodeStringsMarkdown read_only_context_code: str = "" hashing_code_context: str = "" diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 926adff95..d66bafeab 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -309,7 +309,7 @@ def generate_and_instrument_tests( revert_to_print=bool(get_pr_number()), ): generated_results = self.generate_tests_and_optimizations( - testgen_context_code=code_context.testgen_context_code, + testgen_context=code_context.testgen_context, read_writable_code=code_context.read_writable_code, read_only_context_code=code_context.read_only_context_code, helper_functions=code_context.helper_functions, @@ -345,7 +345,6 @@ def generate_and_instrument_tests( logger.info(f"Generated test {i + 1}/{count_tests}:") code_print(generated_test.generated_original_test_source, file_name=f"test_{i + 1}.py") if concolic_test_str: - # no concolic tests in lsp mode logger.info(f"Generated test {count_tests}/{count_tests}:") code_print(concolic_test_str) @@ -946,7 +945,7 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: return Success( CodeOptimizationContext( - testgen_context_code=new_code_ctx.testgen_context_code, + testgen_context=new_code_ctx.testgen_context, read_writable_code=new_code_ctx.read_writable_code, read_only_context_code=new_code_ctx.read_only_context_code, hashing_code_context=new_code_ctx.hashing_code_context, @@ -1053,7 +1052,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio def generate_tests_and_optimizations( self, - testgen_context_code: str, + testgen_context: CodeStringsMarkdown, read_writable_code: CodeStringsMarkdown, read_only_context_code: str, helper_functions: list[FunctionSource], @@ -1067,7 +1066,7 @@ def generate_tests_and_optimizations( # Submit the test generation task as future future_tests = self.submit_test_generation_tasks( self.executor, - testgen_context_code, + testgen_context.markdown, [definition.fully_qualified_name for definition in helper_functions], generated_test_paths, generated_perf_test_paths, diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index f7bfaace3..32de8bc4d 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -827,7 +827,7 @@ def main_method(self): ) func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config) code_context = func_optimizer.get_code_optimization_context().unwrap() - assert code_context.testgen_context_code.rstrip() == get_code_output.rstrip() + assert code_context.testgen_context.rstrip() == get_code_output.rstrip() def test_code_replacement11() -> None: diff --git a/tests/test_function_dependencies.py b/tests/test_function_dependencies.py index 019cc4261..49f4fc30e 100644 --- a/tests/test_function_dependencies.py +++ b/tests/test_function_dependencies.py @@ -160,7 +160,7 @@ def test_class_method_dependencies() -> None: ) assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil" assert ( - code_context.testgen_context_code + code_context.testgen_context == """from collections import defaultdict class Graph: @@ -220,7 +220,7 @@ def test_recursive_function_context() -> None: assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3" assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive" assert ( - code_context.testgen_context_code + code_context.testgen_context == """class C: def calculate_something_3(self, num): return num + 1 diff --git a/tests/test_get_helper_code.py b/tests/test_get_helper_code.py index 36359d3e3..c3382e513 100644 --- a/tests/test_get_helper_code.py +++ b/tests/test_get_helper_code.py @@ -241,7 +241,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: code_context = ctx_result.unwrap() assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call" assert ( - code_context.testgen_context_code + code_context.testgen_context == f'''_P = ParamSpec("_P") _KEY_T = TypeVar("_KEY_T") _STORE_T = TypeVar("_STORE_T") @@ -409,7 +409,7 @@ def test_bubble_sort_deps() -> None: pytest.fail() code_context = ctx_result.unwrap() assert ( - code_context.testgen_context_code + code_context.testgen_context == """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer from code_to_optimize.bubble_sort_dep2_swap import dep2_swap From 4b41ab73bf341f478fded598a17bc91e81e01ad9 Mon Sep 17 00:00:00 2001 From: ali Date: Thu, 25 Sep 2025 04:10:25 +0300 Subject: [PATCH 2/4] fix: unit tests --- tests/test_code_replacement.py | 5 +++-- tests/test_function_dependencies.py | 10 ++++++---- tests/test_get_helper_code.py | 19 +++++++++++-------- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 32de8bc4d..2d547e27e 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -797,7 +797,8 @@ def main_method(self): def test_code_replacement10() -> None: - get_code_output = """from __future__ import annotations + get_code_output = """# file: test_code_replacement.py +from __future__ import annotations class HelperClass: def __init__(self, name): @@ -827,7 +828,7 @@ def main_method(self): ) func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config) code_context = func_optimizer.get_code_optimization_context().unwrap() - assert code_context.testgen_context.rstrip() == get_code_output.rstrip() + assert code_context.testgen_context.flat.rstrip() == get_code_output.rstrip() def test_code_replacement11() -> None: diff --git a/tests/test_function_dependencies.py b/tests/test_function_dependencies.py index 49f4fc30e..4a886ba8d 100644 --- a/tests/test_function_dependencies.py +++ b/tests/test_function_dependencies.py @@ -160,8 +160,9 @@ def test_class_method_dependencies() -> None: ) assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil" assert ( - code_context.testgen_context - == """from collections import defaultdict + code_context.testgen_context.flat + == """# file: test_function_dependencies.py +from collections import defaultdict class Graph: def __init__(self, vertices): @@ -220,8 +221,9 @@ def test_recursive_function_context() -> None: assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3" assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive" assert ( - code_context.testgen_context - == """class C: + code_context.testgen_context.flat + == """# file: test_function_dependencies.py +class C: def calculate_something_3(self, num): return num + 1 diff --git a/tests/test_get_helper_code.py b/tests/test_get_helper_code.py index c3382e513..5cf2c963e 100644 --- a/tests/test_get_helper_code.py +++ b/tests/test_get_helper_code.py @@ -241,8 +241,9 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: code_context = ctx_result.unwrap() assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call" assert ( - code_context.testgen_context - == f'''_P = ParamSpec("_P") + code_context.testgen_context.flat + == f'''# file: {file_path.relative_to(project_root_path)} +_P = ParamSpec("_P") _KEY_T = TypeVar("_KEY_T") _STORE_T = TypeVar("_STORE_T") class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): @@ -394,10 +395,11 @@ def test_bubble_sort_deps() -> None: function_to_optimize = FunctionToOptimize( function_name="sorter_deps", file_path=file_path, parents=[], starting_line=None, ending_line=None ) + project_root = file_path.parent.parent.resolve() test_config = TestConfig( tests_root=str(file_path.parent / "tests"), tests_project_rootdir=file_path.parent.resolve(), - project_root_path=file_path.parent.parent.resolve(), + project_root_path=project_root, test_framework="pytest", pytest_cmd="pytest", ) @@ -409,19 +411,20 @@ def test_bubble_sort_deps() -> None: pytest.fail() code_context = ctx_result.unwrap() assert ( - code_context.testgen_context - == """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer -from code_to_optimize.bubble_sort_dep2_swap import dep2_swap - + code_context.testgen_context.flat + == f"""# file: code_to_optimize/bubble_sort_dep1_helper.py def dep1_comparer(arr, j: int) -> bool: return arr[j] > arr[j + 1] +# file: code_to_optimize/bubble_sort_dep2_swap.py def dep2_swap(arr, j): temp = arr[j] arr[j] = arr[j + 1] arr[j + 1] = temp - +# file: code_to_optimize/bubble_sort_deps.py +from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer +from code_to_optimize.bubble_sort_dep2_swap import dep2_swap def sorter_deps(arr): for i in range(len(arr)): From c7941e9060cd0fc5e092ba979ff96caaf2a64733 Mon Sep 17 00:00:00 2001 From: ali Date: Wed, 8 Oct 2025 02:15:29 +0300 Subject: [PATCH 3/4] generate the mock test context correctly --- tests/test_code_utils.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/tests/test_code_utils.py b/tests/test_code_utils.py index 1fa2f95fe..c7abbae7b 100644 --- a/tests/test_code_utils.py +++ b/tests/test_code_utils.py @@ -21,6 +21,7 @@ ) from codeflash.code_utils.concolic_utils import clean_concolic_tests from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files +from codeflash.models.models import CodeStringsMarkdown @pytest.fixture @@ -382,69 +383,76 @@ def mock_code_context(): def test_extract_dependent_function_sync_and_async(mock_code_context): """Test extract_dependent_function with both sync and async functions.""" # Test sync function extraction - mock_code_context.testgen_context_code = """ + mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py def main_function(): pass def helper_function(): pass -""" +``` +""") assert extract_dependent_function("main_function", mock_code_context) == "helper_function" # Test async function extraction - mock_code_context.testgen_context_code = """ + mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py def main_function(): pass async def async_helper_function(): pass -""" +``` +""") + assert extract_dependent_function("main_function", mock_code_context) == "async_helper_function" def test_extract_dependent_function_edge_cases(mock_code_context): """Test extract_dependent_function edge cases.""" # No dependent functions - mock_code_context.testgen_context_code = """ + mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py def main_function(): pass -""" +``` +""") assert extract_dependent_function("main_function", mock_code_context) is False # Multiple dependent functions - mock_code_context.testgen_context_code = """ + mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py def main_function(): pass - def helper1(): pass async def helper2(): pass -""" +``` +""") assert extract_dependent_function("main_function", mock_code_context) is False def test_extract_dependent_function_mixed_scenarios(mock_code_context): """Test extract_dependent_function with mixed sync/async scenarios.""" # Async main with sync helper - mock_code_context.testgen_context_code = """ + mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py async def async_main(): pass def sync_helper(): pass -""" +``` +""") assert extract_dependent_function("async_main", mock_code_context) == "sync_helper" # Only async functions - mock_code_context.testgen_context_code = """ + mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py async def async_main(): pass async def async_helper(): pass -""" +``` +""") + assert extract_dependent_function("async_main", mock_code_context) == "async_helper" From d863a0d7dc0df0b69e7dd7513f5e66a3a4a1be49 Mon Sep 17 00:00:00 2001 From: ali Date: Wed, 8 Oct 2025 02:32:03 +0300 Subject: [PATCH 4/4] fix test for windows --- codeflash/models/models.py | 2 +- tests/test_get_helper_code.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 8f43f09d5..84179054e 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -163,7 +163,7 @@ class CodeString(BaseModel): def get_code_block_splitter(file_path: Path) -> str: - return f"# file: {file_path}" + return f"# file: {file_path.as_posix()}" markdown_pattern = re.compile(r"```python:([^\n]+)\n(.*?)\n```", re.DOTALL) diff --git a/tests/test_get_helper_code.py b/tests/test_get_helper_code.py index 26260ff4e..7ea5056dd 100644 --- a/tests/test_get_helper_code.py +++ b/tests/test_get_helper_code.py @@ -6,7 +6,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import is_successful -from codeflash.models.models import FunctionParent +from codeflash.models.models import FunctionParent, get_code_block_splitter from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.optimization.optimizer import Optimizer from codeflash.verification.verification_utils import TestConfig @@ -413,17 +413,17 @@ def test_bubble_sort_deps() -> None: code_context = ctx_result.unwrap() assert ( code_context.testgen_context.flat - == f"""# file: code_to_optimize/bubble_sort_dep1_helper.py + == f"""{get_code_block_splitter(Path("code_to_optimize/bubble_sort_dep1_helper.py"))} def dep1_comparer(arr, j: int) -> bool: return arr[j] > arr[j + 1] -# file: code_to_optimize/bubble_sort_dep2_swap.py +{get_code_block_splitter(Path("code_to_optimize/bubble_sort_dep2_swap.py"))} def dep2_swap(arr, j): temp = arr[j] arr[j] = arr[j + 1] arr[j + 1] = temp -# file: code_to_optimize/bubble_sort_deps.py +{get_code_block_splitter(Path("code_to_optimize/bubble_sort_deps.py"))} from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer from code_to_optimize.bubble_sort_dep2_swap import dep2_swap