diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index d4456bb62..29c0be556 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -46,9 +46,9 @@ def file_name_from_test_module_name(test_module_name: str, base_dir: Path) -> Pa def get_imports_from_file( file_path: Path | None = None, file_string: str | None = None, file_ast: ast.AST | None = None ) -> list[ast.Import | ast.ImportFrom]: - assert ( - sum([file_path is not None, file_string is not None, file_ast is not None]) == 1 - ), "Must provide exactly one of file_path, file_string, or file_ast" + assert sum([file_path is not None, file_string is not None, file_ast is not None]) == 1, ( + "Must provide exactly one of file_path, file_string, or file_ast" + ) if file_path: with file_path.open(encoding="utf8") as file: file_string = file.read() @@ -107,6 +107,14 @@ def validate_python_code(code: str) -> str: return code +def has_any_async_functions(code: str) -> bool: + try: + module = ast.parse(code) + except SyntaxError: + return False + return any(isinstance(node, ast.AsyncFunctionDef) for node in ast.walk(module)) + + def cleanup_paths(paths: list[Path]) -> None: for path in paths: path.unlink(missing_ok=True) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 7b067a094..f2d58f9f2 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -28,6 +28,7 @@ file_name_from_test_module_name, get_run_tmp_file, module_name_from_file_path, + has_any_async_functions, ) from codeflash.code_utils.config_consts import ( INDIVIDUAL_TESTCASE_TIMEOUT, @@ -136,8 +137,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: with helper_function_path.open(encoding="utf8") as f: helper_code = f.read() original_helper_code[helper_function_path] = helper_code - - logger.info("Code to be optimized:") + if has_any_async_functions(code_context.code_to_optimize_with_helpers): + return Failure("Codeflash does not support async functions in the code to optimize.") code_print(code_context.read_writable_code) for module_abspath, helper_code_source in original_helper_code.items(): diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index cae78a153..ed6caa669 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -60,7 +60,6 @@ def create_function_optimizer( function_to_optimize_ast=function_to_optimize_ast, aiservice_client=self.aiservice_client, args=self.args, - ) def run(self) -> None: @@ -140,6 +139,7 @@ def run(self) -> None: validated_original_code[analysis.file_path] = ValidCode( source_code=callee_original_code, normalized_code=normalized_callee_original_code ) + if has_syntax_error: continue @@ -149,7 +149,7 @@ def run(self) -> None: f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: " f"{function_to_optimize.qualified_name}" ) - + console.rule() if not ( function_to_optimize_ast := get_first_top_level_function_or_method_ast( function_to_optimize.function_name, function_to_optimize.parents, original_module_ast @@ -160,9 +160,11 @@ def run(self) -> None: f"Skipping optimization." ) continue - function_optimizer = self.create_function_optimizer( - function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code + function_to_optimize, + function_to_optimize_ast, + function_to_tests, + validated_original_code[original_module_path].source_code, ) best_optimization = function_optimizer.optimize_function() if is_successful(best_optimization): @@ -192,7 +194,6 @@ def run(self) -> None: get_run_tmp_file.tmpdir.cleanup() - def run_with_args(args: Namespace) -> None: optimizer = Optimizer(args) optimizer.run() diff --git a/tests/test_code_utils.py b/tests/test_code_utils.py index 85719f4f9..a10f50a56 100644 --- a/tests/test_code_utils.py +++ b/tests/test_code_utils.py @@ -17,6 +17,7 @@ is_class_defined_in_file, module_name_from_file_path, path_belongs_to_site_packages, + has_any_async_functions, ) from codeflash.code_utils.concolic_utils import clean_concolic_tests from codeflash.code_utils.coverage_utils import generate_candidates, prepare_coverage_files @@ -441,3 +442,27 @@ def test_Grammar_copy(): Grammar.copy(Grammar()) """ assert cleaned_code == expected_cleaned_code.strip() + + +def test_has_any_async_functions_with_async_code() -> None: + code = """ +def normal_function(): + pass + +async def async_function(): + pass +""" + result = has_any_async_functions(code) + assert result is True + + +def test_has_any_async_functions_without_async_code() -> None: + code = """ +def normal_function(): + pass + +def another_function(): + pass +""" + result = has_any_async_functions(code) + assert result is False