Skip to content

Commit 018581f

Browse files
committed
skip async funcs in the helper code
1 parent 69e43dd commit 018581f

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

codeflash/code_utils/code_utils.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ def file_name_from_test_module_name(test_module_name: str, base_dir: Path) -> Pa
4646
def get_imports_from_file(
4747
file_path: Path | None = None, file_string: str | None = None, file_ast: ast.AST | None = None
4848
) -> list[ast.Import | ast.ImportFrom]:
49-
assert (
50-
sum([file_path is not None, file_string is not None, file_ast is not None]) == 1
51-
), "Must provide exactly one of file_path, file_string, or file_ast"
49+
assert sum([file_path is not None, file_string is not None, file_ast is not None]) == 1, (
50+
"Must provide exactly one of file_path, file_string, or file_ast"
51+
)
5252
if file_path:
5353
with file_path.open(encoding="utf8") as file:
5454
file_string = file.read()
@@ -107,6 +107,14 @@ def validate_python_code(code: str) -> str:
107107
return code
108108

109109

110+
def has_any_async_functions(code: str) -> bool:
111+
try:
112+
module = ast.parse(code)
113+
except SyntaxError:
114+
return False
115+
return any(isinstance(node, ast.AsyncFunctionDef) for node in ast.walk(module))
116+
117+
110118
def cleanup_paths(paths: list[Path]) -> None:
111119
for path in paths:
112120
path.unlink(missing_ok=True)

codeflash/optimization/function_optimizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
file_name_from_test_module_name,
2929
get_run_tmp_file,
3030
module_name_from_file_path,
31+
has_any_async_functions,
3132
)
3233
from codeflash.code_utils.config_consts import (
3334
INDIVIDUAL_TESTCASE_TIMEOUT,
@@ -136,8 +137,8 @@ def optimize_function(self) -> Result[BestOptimization, str]:
136137
with helper_function_path.open(encoding="utf8") as f:
137138
helper_code = f.read()
138139
original_helper_code[helper_function_path] = helper_code
139-
140-
logger.info("Code to be optimized:")
140+
if has_any_async_functions(code_context.code_to_optimize_with_helpers):
141+
return Failure("Codeflash does not support async functions in the code to optimize.")
141142
code_print(code_context.read_writable_code)
142143

143144
for module_abspath, helper_code_source in original_helper_code.items():

codeflash/optimization/optimizer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def create_function_optimizer(
6060
function_to_optimize_ast=function_to_optimize_ast,
6161
aiservice_client=self.aiservice_client,
6262
args=self.args,
63-
6463
)
6564

6665
def run(self) -> None:
@@ -140,6 +139,7 @@ def run(self) -> None:
140139
validated_original_code[analysis.file_path] = ValidCode(
141140
source_code=callee_original_code, normalized_code=normalized_callee_original_code
142141
)
142+
143143
if has_syntax_error:
144144
continue
145145

@@ -149,7 +149,7 @@ def run(self) -> None:
149149
f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: "
150150
f"{function_to_optimize.qualified_name}"
151151
)
152-
152+
console.rule()
153153
if not (
154154
function_to_optimize_ast := get_first_top_level_function_or_method_ast(
155155
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
@@ -160,9 +160,11 @@ def run(self) -> None:
160160
f"Skipping optimization."
161161
)
162162
continue
163-
164163
function_optimizer = self.create_function_optimizer(
165-
function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code
164+
function_to_optimize,
165+
function_to_optimize_ast,
166+
function_to_tests,
167+
validated_original_code[original_module_path].source_code,
166168
)
167169
best_optimization = function_optimizer.optimize_function()
168170
if is_successful(best_optimization):
@@ -192,7 +194,6 @@ def run(self) -> None:
192194
get_run_tmp_file.tmpdir.cleanup()
193195

194196

195-
196197
def run_with_args(args: Namespace) -> None:
197198
optimizer = Optimizer(args)
198199
optimizer.run()

0 commit comments

Comments
 (0)