diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index dfd79a76b..8d16b6d3d 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -251,14 +251,6 @@ def validate_python_code(code: str) -> str: return code -def has_any_async_functions(code: str) -> bool: - try: - module = ast.parse(code) - except SyntaxError: - return False - return any(isinstance(node, ast.AsyncFunctionDef) for node in ast.walk(module)) - - def cleanup_paths(paths: list[Path]) -> None: for path in paths: if path and path.exists(): diff --git a/codeflash/code_utils/static_analysis.py b/codeflash/code_utils/static_analysis.py index dbddb59f5..0151e29e7 100644 --- a/codeflash/code_utils/static_analysis.py +++ b/codeflash/code_utils/static_analysis.py @@ -128,13 +128,19 @@ def get_first_top_level_object_def_ast( def get_first_top_level_function_or_method_ast( function_name: str, parents: list[FunctionParent], node: ast.AST -) -> ast.FunctionDef | None: +) -> ast.FunctionDef | ast.AsyncFunctionDef | None: if not parents: - return get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node) + result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node) + if result is not None: + return result + return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, node) if parents[0].type == "ClassDef" and ( class_node := get_first_top_level_object_def_ast(parents[0].name, ast.ClassDef, node) ): - return get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node) + result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node) + if result is not None: + return result + return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, class_node) return None diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 27a46af0a..94d265968 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -86,6 +86,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: parents=list(reversed(ast_parents)), starting_line=pos.start.line, ending_line=pos.end.line, + is_async=bool(node.asynchronous), ) ) @@ -103,6 +104,15 @@ def visit_FunctionDef(self, node: FunctionDef) -> None: FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:]) ) + def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None: + # Check if the async function has a return statement and add it to the list + if function_has_return_statement(node) and not function_is_a_property(node): + self.functions.append( + FunctionToOptimize( + function_name=node.name, file_path=self.file_path, parents=self.ast_path[:], is_async=True + ) + ) + def generic_visit(self, node: ast.AST) -> None: if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)): self.ast_path.append(FunctionParent(node.name, node.__class__.__name__)) @@ -122,6 +132,7 @@ class FunctionToOptimize: parents: A list of parent scopes, which could be classes or functions. starting_line: The starting line number of the function in the file. ending_line: The ending line number of the function in the file. + is_async: Whether this function is defined as async. The qualified_name property provides the full name of the function, including any parent class or function names. The qualified_name_with_modules_from_root @@ -134,6 +145,7 @@ class FunctionToOptimize: parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef] starting_line: Optional[int] = None ending_line: Optional[int] = None + is_async: bool = False @property def top_level_parent_name(self) -> str: @@ -402,11 +414,27 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: ) ) + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + if self.class_name is None and node.name == self.function_name: + self.is_top_level = True + self.function_has_args = any( + ( + bool(node.args.args), + bool(node.args.kwonlyargs), + bool(node.args.kwarg), + bool(node.args.posonlyargs), + bool(node.args.vararg), + ) + ) + def visit_ClassDef(self, node: ast.ClassDef) -> None: # iterate over the class methods if node.name == self.class_name: for body_node in node.body: - if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name: + if ( + isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef)) + and body_node.name == self.function_name + ): self.is_top_level = True if any( isinstance(decorator, ast.Name) and decorator.id == "classmethod" @@ -424,7 +452,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: # This way, if we don't have the class name, we can still find the static method for body_node in node.body: if ( - isinstance(body_node, ast.FunctionDef) + isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and body_node.name == self.function_name and body_node.lineno in {self.line_no, self.line_no + 1} and any( diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index c523dcbce..0e42b36f3 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -36,7 +36,6 @@ diff_length, file_name_from_test_module_name, get_run_tmp_file, - has_any_async_functions, module_name_from_file_path, restore_conftest, ) @@ -189,7 +188,7 @@ def __init__( test_cfg: TestConfig, function_to_optimize_source_code: str = "", function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None, - function_to_optimize_ast: ast.FunctionDef | None = None, + function_to_optimize_ast: ast.FunctionDef | ast.AsyncFunctionDef | None = None, aiservice_client: AiServiceClient | None = None, function_benchmark_timings: dict[BenchmarkKey, int] | None = None, total_benchmark_timings: dict[BenchmarkKey, int] | None = None, @@ -248,11 +247,6 @@ def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[P helper_code = f.read() original_helper_code[helper_function_path] = helper_code - async_code = any( - has_any_async_functions(code_string.code) for code_string in code_context.read_writable_code.code_strings - ) - if async_code: - return Failure("Codeflash does not support async functions in the code to optimize.") # Random here means that we still attempt optimization with a fractional chance to see if # last time we could not find an optimization, maybe this time we do. # Random is before as a performance optimization, swapping the two 'and' statements has the same effect diff --git a/codeflash/verification/coverage_utils.py b/codeflash/verification/coverage_utils.py index c9044a44d..63ba2c9d6 100644 --- a/codeflash/verification/coverage_utils.py +++ b/codeflash/verification/coverage_utils.py @@ -38,10 +38,10 @@ def load_from_sqlite_database( cov = Coverage(data_file=database_path, config_file=config_path, data_suffix=True, auto_data=True, branch=True) - if not database_path.stat().st_size or not database_path.exists(): + if not database_path.exists() or not database_path.stat().st_size: logger.debug(f"Coverage database {database_path} is empty or does not exist") sentry_sdk.capture_message(f"Coverage database {database_path} is empty or does not exist") - return CoverageUtils.create_empty(source_code_path, function_name, code_context) + return CoverageData.create_empty(source_code_path, function_name, code_context) cov.load() reporter = JsonReporter(cov) @@ -51,7 +51,7 @@ def load_from_sqlite_database( reporter.report(morfs=[source_code_path.as_posix()], outfile=f) except NoDataError: sentry_sdk.capture_message(f"No coverage data found for {function_name} in {source_code_path}") - return CoverageUtils.create_empty(source_code_path, function_name, code_context) + return CoverageData.create_empty(source_code_path, function_name, code_context) with temp_json_file.open() as f: original_coverage_data = json.load(f) diff --git a/tests/test_async_function_discovery.py b/tests/test_async_function_discovery.py new file mode 100644 index 000000000..259d9ee24 --- /dev/null +++ b/tests/test_async_function_discovery.py @@ -0,0 +1,286 @@ +import tempfile +from pathlib import Path +import pytest + +from codeflash.discovery.functions_to_optimize import ( + find_all_functions_in_file, + get_functions_to_optimize, + inspect_top_level_functions_or_methods, +) +from codeflash.verification.verification_utils import TestConfig + + +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as temp: + yield Path(temp) + + +def test_async_function_detection(temp_dir): + async_function = """ +async def async_function_with_return(): + await some_async_operation() + return 42 + +async def async_function_without_return(): + await some_async_operation() + print("No return") + +def regular_function(): + return 10 +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(async_function) + functions_found = find_all_functions_in_file(file_path) + + function_names = [fn.function_name for fn in functions_found[file_path]] + + assert "async_function_with_return" in function_names + assert "regular_function" in function_names + assert "async_function_without_return" not in function_names + + +def test_async_method_in_class(temp_dir): + code_with_async_method = """ +class AsyncClass: + async def async_method(self): + await self.do_something() + return "result" + + async def async_method_no_return(self): + await self.do_something() + pass + + def sync_method(self): + return "sync result" +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(code_with_async_method) + functions_found = find_all_functions_in_file(file_path) + + found_functions = functions_found[file_path] + function_names = [fn.function_name for fn in found_functions] + qualified_names = [fn.qualified_name for fn in found_functions] + + assert "async_method" in function_names + assert "AsyncClass.async_method" in qualified_names + + assert "sync_method" in function_names + assert "AsyncClass.sync_method" in qualified_names + + assert "async_method_no_return" not in function_names + + +def test_nested_async_functions(temp_dir): + nested_async = """ +async def outer_async(): + async def inner_async(): + return "inner" + + result = await inner_async() + return result + +def outer_sync(): + async def inner_async(): + return "inner from sync" + + return inner_async +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(nested_async) + functions_found = find_all_functions_in_file(file_path) + + function_names = [fn.function_name for fn in functions_found[file_path]] + + assert "outer_async" in function_names + assert "outer_sync" in function_names + assert "inner_async" not in function_names + + +def test_async_staticmethod_and_classmethod(temp_dir): + async_decorators = """ +class MyClass: + @staticmethod + async def async_static_method(): + await some_operation() + return "static result" + + @classmethod + async def async_class_method(cls): + await cls.some_operation() + return "class result" + + @property + async def async_property(self): + return await self.get_value() +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(async_decorators) + functions_found = find_all_functions_in_file(file_path) + + function_names = [fn.function_name for fn in functions_found[file_path]] + + assert "async_static_method" in function_names + assert "async_class_method" in function_names + + assert "async_property" not in function_names + + +def test_async_generator_functions(temp_dir): + async_generators = """ +async def async_generator_with_return(): + for i in range(10): + yield i + return "done" + +async def async_generator_no_return(): + for i in range(10): + yield i + +async def regular_async_with_return(): + result = await compute() + return result +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(async_generators) + functions_found = find_all_functions_in_file(file_path) + + function_names = [fn.function_name for fn in functions_found[file_path]] + + assert "async_generator_with_return" in function_names + assert "regular_async_with_return" in function_names + assert "async_generator_no_return" not in function_names + + +def test_inspect_async_top_level_functions(temp_dir): + code = """ +async def top_level_async(): + return 42 + +class AsyncContainer: + async def async_method(self): + async def nested_async(): + return 1 + return await nested_async() + + @staticmethod + async def async_static(): + return "static" + + @classmethod + async def async_classmethod(cls): + return "classmethod" +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(code) + + result = inspect_top_level_functions_or_methods(file_path, "top_level_async") + assert result.is_top_level + + result = inspect_top_level_functions_or_methods(file_path, "async_method", class_name="AsyncContainer") + assert result.is_top_level + + result = inspect_top_level_functions_or_methods(file_path, "nested_async", class_name="AsyncContainer") + assert not result.is_top_level + + result = inspect_top_level_functions_or_methods(file_path, "async_static", class_name="AsyncContainer") + assert result.is_top_level + assert result.is_staticmethod + + result = inspect_top_level_functions_or_methods(file_path, "async_classmethod", class_name="AsyncContainer") + assert result.is_top_level + assert result.is_classmethod + + +def test_get_functions_to_optimize_with_async(temp_dir): + mixed_code = """ +async def async_func_one(): + return await operation_one() + +def sync_func_one(): + return operation_one() + +async def async_func_two(): + print("no return") + +class MixedClass: + async def async_method(self): + return await self.operation() + + def sync_method(self): + return self.operation() +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(mixed_code) + + test_config = TestConfig( + tests_root="tests", + project_root_path=".", + test_framework="pytest", + tests_project_rootdir=Path() + ) + + functions, functions_count, _ = get_functions_to_optimize( + optimize_all=None, + replay_test=None, + file=file_path, + only_get_this_function=None, + test_cfg=test_config, + ignore_paths=[], + project_root=file_path.parent, + module_root=file_path.parent, + ) + + assert functions_count == 4 + + function_names = [fn.function_name for fn in functions[file_path]] + assert "async_func_one" in function_names + assert "sync_func_one" in function_names + assert "async_method" in function_names + assert "sync_method" in function_names + + assert "async_func_two" not in function_names + + +def test_async_function_parents(temp_dir): + complex_structure = """ +class OuterClass: + async def outer_method(self): + return 1 + + class InnerClass: + async def inner_method(self): + return 2 + +async def module_level_async(): + class LocalClass: + async def local_method(self): + return 3 + return LocalClass() +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(complex_structure) + functions_found = find_all_functions_in_file(file_path) + + found_functions = functions_found[file_path] + + for fn in found_functions: + if fn.function_name == "outer_method": + assert len(fn.parents) == 1 + assert fn.parents[0].name == "OuterClass" + assert fn.qualified_name == "OuterClass.outer_method" + elif fn.function_name == "inner_method": + assert len(fn.parents) == 2 + assert fn.parents[0].name == "OuterClass" + assert fn.parents[1].name == "InnerClass" + elif fn.function_name == "module_level_async": + assert len(fn.parents) == 0 + assert fn.qualified_name == "module_level_async" \ No newline at end of file diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 3a7de5d1c..d82a4728b 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -1800,9 +1800,10 @@ def get_system_details(): # Set up the optimizer file_path = main_file_path.resolve() + project_root = package_dir.resolve() opt = Optimizer( Namespace( - project_root=package_dir.resolve(), + project_root=project_root, disable_telemetry=True, tests_root="tests", test_framework="pytest", @@ -1826,8 +1827,10 @@ def get_system_details(): read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context # The expected contexts + # Resolve both paths to handle symlink issues on macOS + relative_path = file_path.relative_to(project_root) expected_read_write_context = f""" -```python:{main_file_path.relative_to(opt.args.project_root)} +```python:{relative_path} import utility_module class Calculator: @@ -2045,9 +2048,10 @@ def get_system_details(): # Set up the optimizer file_path = main_file_path.resolve() + project_root = package_dir.resolve() opt = Optimizer( Namespace( - project_root=package_dir.resolve(), + project_root=project_root, disable_telemetry=True, tests_root="tests", test_framework="pytest", @@ -2070,6 +2074,7 @@ def get_system_details(): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code # The expected contexts + relative_path = file_path.relative_to(project_root) expected_read_write_context = f""" ```python:utility_module.py # Function that will be used in the main code @@ -2096,7 +2101,7 @@ def select_precision(precision, fallback_precision): else: return DEFAULT_PRECISION ``` -```python:{main_file_path.relative_to(opt.args.project_root)} +```python:{relative_path} import utility_module class Calculator: diff --git a/tests/test_code_utils.py b/tests/test_code_utils.py index 0ab78d2ef..4fc28bea2 100644 --- a/tests/test_code_utils.py +++ b/tests/test_code_utils.py @@ -17,7 +17,7 @@ is_class_defined_in_file, module_name_from_file_path, path_belongs_to_site_packages, - has_any_async_functions, + validate_python_code, ) from codeflash.code_utils.concolic_utils import clean_concolic_tests from codeflash.code_utils.coverage_utils import generate_candidates, prepare_coverage_files @@ -445,25 +445,41 @@ def test_Grammar_copy(): assert cleaned_code == expected_cleaned_code.strip() -def test_has_any_async_functions_with_async_code() -> None: - code = """ -def normal_function(): - pass +def test_validate_python_code_valid() -> None: + code = "def hello():\n return 'world'" + result = validate_python_code(code) + assert result == code -async def async_function(): - pass -""" - result = has_any_async_functions(code) - assert result is True +def test_validate_python_code_invalid() -> None: + code = "def hello(:\n return 'world'" + with pytest.raises(ValueError, match="Invalid Python code"): + validate_python_code(code) -def test_has_any_async_functions_without_async_code() -> None: - code = """ -def normal_function(): - pass -def another_function(): - pass +def test_validate_python_code_empty() -> None: + code = "" + result = validate_python_code(code) + assert result == code + + +def test_validate_python_code_complex_invalid() -> None: + code = "if True\n print('missing colon')" + with pytest.raises(ValueError, match="Invalid Python code.*line 1.*column 8"): + validate_python_code(code) + + +def test_validate_python_code_valid_complex() -> None: + code = """ +def calculate(a, b): + if a > b: + return a + b + else: + return a * b + +class MyClass: + def __init__(self): + self.value = 42 """ - result = has_any_async_functions(code) - assert result is False + result = validate_python_code(code) + assert result == code