diff --git a/codeflash/code_utils/coverage_utils.py b/codeflash/code_utils/coverage_utils.py index 966910630..11df7687e 100644 --- a/codeflash/code_utils/coverage_utils.py +++ b/codeflash/code_utils/coverage_utils.py @@ -14,7 +14,9 @@ def extract_dependent_function(main_function: str, code_context: CodeOptimizatio """Extract the single dependent function from the code context excluding the main function.""" ast_tree = ast.parse(code_context.testgen_context_code) - dependent_functions = {node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)} + dependent_functions = { + node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + } if main_function in dependent_functions: dependent_functions.discard(main_function) diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index cf57af031..8d3a623b1 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -642,7 +642,10 @@ def detect_unused_helper_functions( # Find the optimized entrypoint function entrypoint_function_ast = None for node in ast.walk(optimized_ast): - if isinstance(node, ast.FunctionDef) and node.name == function_to_optimize.function_name: + if ( + isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + and node.name == function_to_optimize.function_name + ): entrypoint_function_ast = node break diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index d82a4728b..b6b2aa219 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -12,7 +12,7 @@ from codeflash.models.models import FunctionParent from codeflash.optimization.optimizer import Optimizer from codeflash.code_utils.code_replacer import replace_functions_and_add_imports -from codeflash.code_utils.code_extractor import add_global_assignments +from codeflash.code_utils.code_extractor import add_global_assignments, GlobalAssignmentCollector class HelperClass: @@ -2482,3 +2482,148 @@ def test_circular_deps(): assert "import ApiClient" not in new_code, "Error: Circular dependency found" assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist" +def test_global_assignment_collector_with_async_function(): + """Test GlobalAssignmentCollector correctly identifies global assignments outside async functions.""" + import libcst as cst + + source_code = """ +# Global assignment +GLOBAL_VAR = "global_value" +OTHER_GLOBAL = 42 + +async def async_function(): + # This should not be collected (inside async function) + local_var = "local_value" + INNER_ASSIGNMENT = "should_not_be_global" + return local_var + +# Another global assignment +ANOTHER_GLOBAL = "another_global" +""" + + tree = cst.parse_module(source_code) + collector = GlobalAssignmentCollector() + tree.visit(collector) + + # Should collect global assignments but not the ones inside async function + assert len(collector.assignments) == 3 + assert "GLOBAL_VAR" in collector.assignments + assert "OTHER_GLOBAL" in collector.assignments + assert "ANOTHER_GLOBAL" in collector.assignments + + # Should not collect assignments from inside async function + assert "local_var" not in collector.assignments + assert "INNER_ASSIGNMENT" not in collector.assignments + + # Verify assignment order + expected_order = ["GLOBAL_VAR", "OTHER_GLOBAL", "ANOTHER_GLOBAL"] + assert collector.assignment_order == expected_order + + +def test_global_assignment_collector_nested_async_functions(): + """Test GlobalAssignmentCollector handles nested async functions correctly.""" + import libcst as cst + + source_code = """ +# Global assignment +CONFIG = {"key": "value"} + +def sync_function(): + # Inside sync function - should not be collected + sync_local = "sync" + + async def nested_async(): + # Inside nested async function - should not be collected + nested_var = "nested" + return nested_var + + return sync_local + +async def async_function(): + # Inside async function - should not be collected + async_local = "async" + + def nested_sync(): + # Inside nested function - should not be collected + deeply_nested = "deep" + return deeply_nested + + return async_local + +# Another global assignment +FINAL_GLOBAL = "final" +""" + + tree = cst.parse_module(source_code) + collector = GlobalAssignmentCollector() + tree.visit(collector) + + # Should only collect global-level assignments + assert len(collector.assignments) == 2 + assert "CONFIG" in collector.assignments + assert "FINAL_GLOBAL" in collector.assignments + + # Should not collect any assignments from inside functions + assert "sync_local" not in collector.assignments + assert "nested_var" not in collector.assignments + assert "async_local" not in collector.assignments + assert "deeply_nested" not in collector.assignments + + +def test_global_assignment_collector_mixed_async_sync_with_classes(): + """Test GlobalAssignmentCollector with async functions, sync functions, and classes.""" + import libcst as cst + + source_code = """ +# Global assignments +GLOBAL_CONSTANT = "constant" + +class TestClass: + # Class-level assignment - should not be collected + class_var = "class_value" + + def sync_method(self): + # Method assignment - should not be collected + method_var = "method" + return method_var + + async def async_method(self): + # Async method assignment - should not be collected + async_method_var = "async_method" + return async_method_var + +def sync_function(): + # Function assignment - should not be collected + func_var = "function" + return func_var + +async def async_function(): + # Async function assignment - should not be collected + async_func_var = "async_function" + return async_func_var + +# More global assignments +ANOTHER_CONSTANT = 100 +FINAL_ASSIGNMENT = {"data": "value"} +""" + + tree = cst.parse_module(source_code) + collector = GlobalAssignmentCollector() + tree.visit(collector) + + # Should only collect global-level assignments + assert len(collector.assignments) == 3 + assert "GLOBAL_CONSTANT" in collector.assignments + assert "ANOTHER_CONSTANT" in collector.assignments + assert "FINAL_ASSIGNMENT" in collector.assignments + + # Should not collect assignments from inside any scoped blocks + assert "class_var" not in collector.assignments + assert "method_var" not in collector.assignments + assert "async_method_var" not in collector.assignments + assert "func_var" not in collector.assignments + assert "async_func_var" not in collector.assignments + + # Verify correct order + expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"] + assert collector.assignment_order == expected_order diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 405896087..e760237fa 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -12,6 +12,7 @@ is_zero_diff, replace_functions_and_add_imports, replace_functions_in_file, + OptimFunctionCollector, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent @@ -3453,3 +3454,137 @@ def hydrate_input_text_actions_with_field_names( main_file.unlink(missing_ok=True) assert new_code == expected + + +# OptimFunctionCollector async function tests +def test_optim_function_collector_with_async_functions(): + """Test OptimFunctionCollector correctly collects async functions.""" + import libcst as cst + + source_code = """ +def sync_function(): + return "sync" + +async def async_function(): + return "async" + +class TestClass: + def sync_method(self): + return "sync_method" + + async def async_method(self): + return "async_method" +""" + + tree = cst.parse_module(source_code) + collector = OptimFunctionCollector( + function_names={(None, "sync_function"), (None, "async_function"), ("TestClass", "sync_method"), ("TestClass", "async_method")}, + preexisting_objects=None + ) + tree.visit(collector) + + # Should collect both sync and async functions + assert len(collector.modified_functions) == 4 + assert (None, "sync_function") in collector.modified_functions + assert (None, "async_function") in collector.modified_functions + assert ("TestClass", "sync_method") in collector.modified_functions + assert ("TestClass", "async_method") in collector.modified_functions + + +def test_optim_function_collector_new_async_functions(): + """Test OptimFunctionCollector identifies new async functions not in preexisting objects.""" + import libcst as cst + + source_code = """ +def existing_function(): + return "existing" + +async def new_async_function(): + return "new_async" + +def new_sync_function(): + return "new_sync" + +class ExistingClass: + async def new_class_async_method(self): + return "new_class_async" +""" + + # Only existing_function is in preexisting objects + preexisting_objects = {("existing_function", ())} + + tree = cst.parse_module(source_code) + collector = OptimFunctionCollector( + function_names=set(), # Not looking for specific functions + preexisting_objects=preexisting_objects + ) + tree.visit(collector) + + # Should identify new functions (both sync and async) + assert len(collector.new_functions) == 2 + function_names = [func.name.value for func in collector.new_functions] + assert "new_async_function" in function_names + assert "new_sync_function" in function_names + + # Should identify new class methods + assert "ExistingClass" in collector.new_class_functions + assert len(collector.new_class_functions["ExistingClass"]) == 1 + assert collector.new_class_functions["ExistingClass"][0].name.value == "new_class_async_method" + + +def test_optim_function_collector_mixed_scenarios(): + """Test OptimFunctionCollector with complex mix of sync/async functions and classes.""" + import libcst as cst + + source_code = """ +# Global functions +def global_sync(): + pass + +async def global_async(): + pass + +class ParentClass: + def __init__(self): + pass + + def sync_method(self): + pass + + async def async_method(self): + pass + +class ChildClass: + async def child_async_method(self): + pass + + def child_sync_method(self): + pass +""" + + # Looking for specific functions + function_names = { + (None, "global_sync"), + (None, "global_async"), + ("ParentClass", "sync_method"), + ("ParentClass", "async_method"), + ("ChildClass", "child_async_method") + } + + tree = cst.parse_module(source_code) + collector = OptimFunctionCollector( + function_names=function_names, + preexisting_objects=None + ) + tree.visit(collector) + + # Should collect all specified functions (mix of sync and async) + assert len(collector.modified_functions) == 5 + assert (None, "global_sync") in collector.modified_functions + assert (None, "global_async") in collector.modified_functions + assert ("ParentClass", "sync_method") in collector.modified_functions + assert ("ParentClass", "async_method") in collector.modified_functions + assert ("ChildClass", "child_async_method") in collector.modified_functions + + # Should collect __init__ method + assert "ParentClass" in collector.modified_init_functions diff --git a/tests/test_code_utils.py b/tests/test_code_utils.py index 4fc28bea2..b17fd0758 100644 --- a/tests/test_code_utils.py +++ b/tests/test_code_utils.py @@ -20,7 +20,7 @@ validate_python_code, ) from codeflash.code_utils.concolic_utils import clean_concolic_tests -from codeflash.code_utils.coverage_utils import generate_candidates, prepare_coverage_files +from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files @pytest.fixture @@ -308,6 +308,86 @@ def my_function(): assert is_class_defined_in_file("MyClass", test_file) is False +@pytest.fixture +def mock_code_context(): + """Mock CodeOptimizationContext for testing extract_dependent_function.""" + from unittest.mock import MagicMock + from codeflash.models.models import CodeOptimizationContext + + context = MagicMock(spec=CodeOptimizationContext) + context.preexisting_objects = [] + return context + + +def test_extract_dependent_function_sync_and_async(mock_code_context): + """Test extract_dependent_function with both sync and async functions.""" + # Test sync function extraction + mock_code_context.testgen_context_code = """ +def main_function(): + pass + +def helper_function(): + pass +""" + assert extract_dependent_function("main_function", mock_code_context) == "helper_function" + + # Test async function extraction + mock_code_context.testgen_context_code = """ +def main_function(): + pass + +async def async_helper_function(): + pass +""" + assert extract_dependent_function("main_function", mock_code_context) == "async_helper_function" + + +def test_extract_dependent_function_edge_cases(mock_code_context): + """Test extract_dependent_function edge cases.""" + # No dependent functions + mock_code_context.testgen_context_code = """ +def main_function(): + pass +""" + assert extract_dependent_function("main_function", mock_code_context) is False + + # Multiple dependent functions + mock_code_context.testgen_context_code = """ +def main_function(): + pass + +def helper1(): + pass + +async def helper2(): + pass +""" + assert extract_dependent_function("main_function", mock_code_context) is False + + +def test_extract_dependent_function_mixed_scenarios(mock_code_context): + """Test extract_dependent_function with mixed sync/async scenarios.""" + # Async main with sync helper + mock_code_context.testgen_context_code = """ +async def async_main(): + pass + +def sync_helper(): + pass +""" + assert extract_dependent_function("async_main", mock_code_context) == "sync_helper" + + # Only async functions + mock_code_context.testgen_context_code = """ +async def async_main(): + pass + +async def async_helper(): + pass +""" + assert extract_dependent_function("async_main", mock_code_context) == "async_helper" + + def test_is_class_defined_in_file_with_non_existing_file() -> None: non_existing_file = Path("/non/existing/file.py") diff --git a/tests/test_unused_helper_revert.py b/tests/test_unused_helper_revert.py index 30f291e62..25909a688 100644 --- a/tests/test_unused_helper_revert.py +++ b/tests/test_unused_helper_revert.py @@ -9,6 +9,8 @@ from codeflash.models.models import CodeStringsMarkdown from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig +from codeflash.context.unused_definition_remover import revert_unused_helper_functions + @pytest.fixture @@ -225,6 +227,15 @@ def helper_function_2(x): def test_no_unused_helpers_no_revert(temp_project): """Test that when all helpers are still used, nothing is reverted.""" temp_dir, main_file, test_cfg = temp_project + + + # Store original content to verify nothing changes + original_content = main_file.read_text() + + revert_unused_helper_functions(temp_dir, [], {}) + + # Verify the file content remains unchanged + assert main_file.read_text() == original_content, "File should remain unchanged when no helpers to revert" # Optimized version that still calls both helpers optimized_code = """ @@ -308,17 +319,23 @@ def helper_function_1(x): def helper_function_2(x): \"\"\"Second helper function.\"\"\" return x * 3 + +def helper_function_1(y): # Duplicate name to test line 575 + \"\"\"Overloaded helper function.\"\"\" + return y + 10 """) - # Optimized version that only calls one helper + # Optimized version that only calls one helper with aliased import optimized_code = """ ```python:main.py -from helpers import helper_function_1 +from helpers import helper_function_1 as h1 +import helpers as h_module def entrypoint_function(n): - \"\"\"Optimized function that only calls one helper.\"\"\" - result1 = helper_function_1(n) - return result1 + n * 3 # Inlined helper_function_2 + \"\"\"Optimized function that only calls one helper with aliasing.\"\"\" + result1 = h1(n) # Using aliased import + # Inlined helper_function_2 functionality: n * 3 + return result1 + n * 3 # Fully inlined helper_function_2 ``` """ @@ -1460,3 +1477,595 @@ def calculate_class(cls, n): import shutil shutil.rmtree(temp_dir, ignore_errors=True) + + +def test_async_entrypoint_with_async_helpers(): + """Test that unused async helper functions are correctly detected when entrypoint is async.""" + temp_dir = Path(tempfile.mkdtemp()) + + try: + # Main file with async entrypoint and async helpers + main_file = temp_dir / "main.py" + main_file.write_text(""" +async def async_helper_1(x): + \"\"\"First async helper function.\"\"\" + return x * 2 + +async def async_helper_2(x): + \"\"\"Second async helper function.\"\"\" + return x * 3 + +async def async_entrypoint(n): + \"\"\"Async entrypoint function that calls async helpers.\"\"\" + result1 = await async_helper_1(n) + result2 = await async_helper_2(n) + return result1 + result2 +""") + + # Optimized version that only calls one async helper + optimized_code = """ +```python:main.py +async def async_helper_1(x): + \"\"\"First async helper function.\"\"\" + return x * 2 + +async def async_helper_2(x): + \"\"\"Second async helper function - should be unused.\"\"\" + return x * 3 + +async def async_entrypoint(n): + \"\"\"Optimized async entrypoint that only calls one helper.\"\"\" + result1 = await async_helper_1(n) + return result1 + n * 3 # Inlined async_helper_2 +``` +""" + + # Create test config + test_cfg = TestConfig( + tests_root=temp_dir / "tests", + tests_project_rootdir=temp_dir, + project_root_path=temp_dir, + test_framework="pytest", + pytest_cmd="pytest", + ) + + # Create FunctionToOptimize instance for async function + function_to_optimize = FunctionToOptimize( + file_path=main_file, + function_name="async_entrypoint", + parents=[], + is_async=True + ) + + # Create function optimizer + optimizer = FunctionOptimizer( + function_to_optimize=function_to_optimize, + test_cfg=test_cfg, + function_to_optimize_source_code=main_file.read_text(), + ) + + # Get original code context + ctx_result = optimizer.get_code_optimization_context() + assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}" + + code_context = ctx_result.unwrap() + + # Test unused helper detection + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) + + # Should detect async_helper_2 as unused + unused_names = {uh.qualified_name for uh in unused_helpers} + expected_unused = {"async_helper_2"} + + assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}" + + finally: + # Cleanup + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + +def test_sync_entrypoint_with_async_helpers(): + """Test that unused async helper functions are detected when entrypoint is sync.""" + temp_dir = Path(tempfile.mkdtemp()) + + try: + # Main file with sync entrypoint and async helpers + main_file = temp_dir / "main.py" + main_file.write_text(""" +import asyncio + +async def async_helper_1(x): + \"\"\"First async helper function.\"\"\" + return x * 2 + +async def async_helper_2(x): + \"\"\"Second async helper function.\"\"\" + return x * 3 + +def sync_entrypoint(n): + \"\"\"Sync entrypoint function that calls async helpers.\"\"\" + result1 = asyncio.run(async_helper_1(n)) + result2 = asyncio.run(async_helper_2(n)) + return result1 + result2 +""") + + # Optimized version that only calls one async helper + optimized_code = """ +```python:main.py +import asyncio + +async def async_helper_1(x): + \"\"\"First async helper function.\"\"\" + return x * 2 + +async def async_helper_2(x): + \"\"\"Second async helper function - should be unused.\"\"\" + return x * 3 + +def sync_entrypoint(n): + \"\"\"Optimized sync entrypoint that only calls one async helper.\"\"\" + result1 = asyncio.run(async_helper_1(n)) + return result1 + n * 3 # Inlined async_helper_2 +``` +""" + + # Create test config + test_cfg = TestConfig( + tests_root=temp_dir / "tests", + tests_project_rootdir=temp_dir, + project_root_path=temp_dir, + test_framework="pytest", + pytest_cmd="pytest", + ) + + # Create FunctionToOptimize instance for sync function + function_to_optimize = FunctionToOptimize( + file_path=main_file, + function_name="sync_entrypoint", + parents=[] + ) + + # Create function optimizer + optimizer = FunctionOptimizer( + function_to_optimize=function_to_optimize, + test_cfg=test_cfg, + function_to_optimize_source_code=main_file.read_text(), + ) + + # Get original code context + ctx_result = optimizer.get_code_optimization_context() + assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}" + + code_context = ctx_result.unwrap() + + # Test unused helper detection + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) + + # Should detect async_helper_2 as unused + unused_names = {uh.qualified_name for uh in unused_helpers} + expected_unused = {"async_helper_2"} + + assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}" + + finally: + # Cleanup + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + +def test_mixed_sync_and_async_helpers(): + """Test detection when both sync and async helpers are mixed.""" + temp_dir = Path(tempfile.mkdtemp()) + + try: + # Main file with mixed sync and async helpers + main_file = temp_dir / "main.py" + main_file.write_text(""" +import asyncio + +def sync_helper_1(x): + \"\"\"Sync helper function.\"\"\" + return x * 2 + +async def async_helper_1(x): + \"\"\"Async helper function.\"\"\" + return x * 3 + +def sync_helper_2(x): + \"\"\"Another sync helper function.\"\"\" + return x * 4 + +async def async_helper_2(x): + \"\"\"Another async helper function.\"\"\" + return x * 5 + +async def mixed_entrypoint(n): + \"\"\"Async entrypoint function that calls both sync and async helpers.\"\"\" + sync_result = sync_helper_1(n) + async_result = await async_helper_1(n) + sync_result2 = sync_helper_2(n) + async_result2 = await async_helper_2(n) + return sync_result + async_result + sync_result2 + async_result2 +""") + + # Optimized version that only calls some helpers + optimized_code = """ +```python:main.py +import asyncio + +def sync_helper_1(x): + \"\"\"Sync helper function.\"\"\" + return x * 2 + +async def async_helper_1(x): + \"\"\"Async helper function.\"\"\" + return x * 3 + +def sync_helper_2(x): + \"\"\"Another sync helper function - should be unused.\"\"\" + return x * 4 + +async def async_helper_2(x): + \"\"\"Another async helper function - should be unused.\"\"\" + return x * 5 + +async def mixed_entrypoint(n): + \"\"\"Optimized async entrypoint that only calls some helpers.\"\"\" + sync_result = sync_helper_1(n) + async_result = await async_helper_1(n) + return sync_result + async_result + n * 4 + n * 5 # Inlined both helper_2 functions +``` +""" + + # Create test config + test_cfg = TestConfig( + tests_root=temp_dir / "tests", + tests_project_rootdir=temp_dir, + project_root_path=temp_dir, + test_framework="pytest", + pytest_cmd="pytest", + ) + + # Create FunctionToOptimize instance for async function + function_to_optimize = FunctionToOptimize( + file_path=main_file, + function_name="mixed_entrypoint", + parents=[], + is_async=True + ) + + # Create function optimizer + optimizer = FunctionOptimizer( + function_to_optimize=function_to_optimize, + test_cfg=test_cfg, + function_to_optimize_source_code=main_file.read_text(), + ) + + # Get original code context + ctx_result = optimizer.get_code_optimization_context() + assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}" + + code_context = ctx_result.unwrap() + + # Test unused helper detection + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) + + # Should detect both sync_helper_2 and async_helper_2 as unused + unused_names = {uh.qualified_name for uh in unused_helpers} + expected_unused = {"sync_helper_2", "async_helper_2"} + + assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}" + + finally: + # Cleanup + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + +def test_async_class_methods(): + """Test unused async method detection in classes.""" + temp_dir = Path(tempfile.mkdtemp()) + + try: + # Main file with class containing async methods + main_file = temp_dir / "main.py" + main_file.write_text(""" +class AsyncProcessor: + async def entrypoint_method(self, n): + \"\"\"Async main method that calls async helper methods.\"\"\" + result1 = await self.async_helper_method_1(n) + result2 = await self.async_helper_method_2(n) + return result1 + result2 + + async def async_helper_method_1(self, x): + \"\"\"First async helper method.\"\"\" + return x * 2 + + async def async_helper_method_2(self, x): + \"\"\"Second async helper method.\"\"\" + return x * 3 + + def sync_helper_method(self, x): + \"\"\"Sync helper method.\"\"\" + return x * 4 +""") + + # Optimized version that only calls one async helper + optimized_code = """ +```python:main.py +class AsyncProcessor: + async def entrypoint_method(self, n): + \"\"\"Optimized async method that only calls one helper.\"\"\" + result1 = await self.async_helper_method_1(n) + return result1 + n * 3 # Inlined async_helper_method_2 + + async def async_helper_method_1(self, x): + \"\"\"First async helper method.\"\"\" + return x * 2 + + async def async_helper_method_2(self, x): + \"\"\"Second async helper method - should be unused.\"\"\" + return x * 3 + + def sync_helper_method(self, x): + \"\"\"Sync helper method - should be unused.\"\"\" + return x * 4 +``` +""" + + # Create test config + test_cfg = TestConfig( + tests_root=temp_dir / "tests", + tests_project_rootdir=temp_dir, + project_root_path=temp_dir, + test_framework="pytest", + pytest_cmd="pytest", + ) + + # Create FunctionToOptimize instance for async class method + from codeflash.models.models import FunctionParent + + function_to_optimize = FunctionToOptimize( + file_path=main_file, + function_name="entrypoint_method", + parents=[FunctionParent(name="AsyncProcessor", type="ClassDef")], + is_async=True + ) + + # Create function optimizer + optimizer = FunctionOptimizer( + function_to_optimize=function_to_optimize, + test_cfg=test_cfg, + function_to_optimize_source_code=main_file.read_text(), + ) + + # Get original code context + ctx_result = optimizer.get_code_optimization_context() + assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}" + + code_context = ctx_result.unwrap() + + # Test unused helper detection + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) + + # Should detect async_helper_method_2 as unused (sync_helper_method may not be discovered as helper) + unused_names = {uh.qualified_name for uh in unused_helpers} + expected_unused = {"AsyncProcessor.async_helper_method_2"} + + assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}" + + finally: + # Cleanup + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + +def test_async_helper_revert_functionality(): + """Test that unused async helper functions are correctly reverted to original definitions.""" + temp_dir = Path(tempfile.mkdtemp()) + + try: + # Main file with async functions + main_file = temp_dir / "main.py" + main_file.write_text(""" +async def async_helper_1(x): + \"\"\"First async helper function.\"\"\" + return x * 2 + +async def async_helper_2(x): + \"\"\"Second async helper function.\"\"\" + return x * 3 + +async def async_entrypoint(n): + \"\"\"Async entrypoint function that calls async helpers.\"\"\" + result1 = await async_helper_1(n) + result2 = await async_helper_2(n) + return result1 + result2 +""") + + # Optimized version that only calls one helper and modifies the unused one + optimized_code = """ +```python:main.py +async def async_helper_1(x): + \"\"\"First async helper function.\"\"\" + return x * 2 + +async def async_helper_2(x): + \"\"\"Modified async helper function - should be reverted.\"\"\" + return x * 10 # This change should be reverted + +async def async_entrypoint(n): + \"\"\"Optimized async entrypoint that only calls one helper.\"\"\" + result1 = await async_helper_1(n) + return result1 + n * 3 # Inlined async_helper_2 +``` +""" + + # Create test config + test_cfg = TestConfig( + tests_root=temp_dir / "tests", + tests_project_rootdir=temp_dir, + project_root_path=temp_dir, + test_framework="pytest", + pytest_cmd="pytest", + ) + + # Create FunctionToOptimize instance for async function + function_to_optimize = FunctionToOptimize( + file_path=main_file, + function_name="async_entrypoint", + parents=[], + is_async=True + ) + + # Create function optimizer + optimizer = FunctionOptimizer( + function_to_optimize=function_to_optimize, + test_cfg=test_cfg, + function_to_optimize_source_code=main_file.read_text(), + ) + + # Get original code context + ctx_result = optimizer.get_code_optimization_context() + assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}" + + code_context = ctx_result.unwrap() + + # Store original helper code + original_helper_code = {main_file: main_file.read_text()} + + # Apply optimization and test reversion + optimizer.replace_function_and_helpers_with_optimized_code( + code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code + ) + + # Check final file content + final_content = main_file.read_text() + + # The entrypoint should be optimized + assert "result1 + n * 3" in final_content, "Async entrypoint function should be optimized" + + # async_helper_2 should be reverted to original (return x * 3, not x * 10) + assert "return x * 3" in final_content, "async_helper_2 should be reverted to original" + assert "return x * 10" not in final_content, "async_helper_2 should not contain the modified version" + + # async_helper_1 should remain (it's still called) + assert "async def async_helper_1(x):" in final_content, "async_helper_1 should still exist" + + finally: + # Cleanup + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + +def test_async_generators_and_coroutines(): + """Test detection with async generators and coroutines.""" + temp_dir = Path(tempfile.mkdtemp()) + + try: + # Main file with async generators and coroutines + main_file = temp_dir / "main.py" + main_file.write_text(""" +import asyncio + +async def async_generator_helper(n): + \"\"\"Async generator helper.\"\"\" + for i in range(n): + yield i * 2 + +async def coroutine_helper(x): + \"\"\"Coroutine helper.\"\"\" + await asyncio.sleep(0.1) + return x * 3 + +async def another_coroutine_helper(x): + \"\"\"Another coroutine helper.\"\"\" + await asyncio.sleep(0.1) + return x * 4 + +async def async_entrypoint_with_generators(n): + \"\"\"Async entrypoint function that uses generators and coroutines.\"\"\" + results = [] + async for value in async_generator_helper(n): + results.append(value) + + final_result = await coroutine_helper(sum(results)) + another_result = await another_coroutine_helper(n) + return final_result + another_result +""") + + # Optimized version that doesn't use one of the coroutines + optimized_code = """ +```python:main.py +import asyncio + +async def async_generator_helper(n): + \"\"\"Async generator helper.\"\"\" + for i in range(n): + yield i * 2 + +async def coroutine_helper(x): + \"\"\"Coroutine helper.\"\"\" + await asyncio.sleep(0.1) + return x * 3 + +async def another_coroutine_helper(x): + \"\"\"Another coroutine helper - should be unused.\"\"\" + await asyncio.sleep(0.1) + return x * 4 + +async def async_entrypoint_with_generators(n): + \"\"\"Optimized async entrypoint that inlines one coroutine.\"\"\" + results = [] + async for value in async_generator_helper(n): + results.append(value) + + final_result = await coroutine_helper(sum(results)) + return final_result + n * 4 # Inlined another_coroutine_helper +``` +""" + + # Create test config + test_cfg = TestConfig( + tests_root=temp_dir / "tests", + tests_project_rootdir=temp_dir, + project_root_path=temp_dir, + test_framework="pytest", + pytest_cmd="pytest", + ) + + # Create FunctionToOptimize instance for async function + function_to_optimize = FunctionToOptimize( + file_path=main_file, + function_name="async_entrypoint_with_generators", + parents=[], + is_async=True + ) + + # Create function optimizer + optimizer = FunctionOptimizer( + function_to_optimize=function_to_optimize, + test_cfg=test_cfg, + function_to_optimize_source_code=main_file.read_text(), + ) + + # Get original code context + ctx_result = optimizer.get_code_optimization_context() + assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}" + + code_context = ctx_result.unwrap() + + # Test unused helper detection + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) + + # Should detect another_coroutine_helper as unused + unused_names = {uh.qualified_name for uh in unused_helpers} + expected_unused = {"another_coroutine_helper"} + + assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}" + + finally: + # Cleanup + import shutil + shutil.rmtree(temp_dir, ignore_errors=True)