Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,6 @@ def validate_python_code(code: str) -> str:
return code


def has_any_async_functions(code: str) -> bool:
    """Report whether ``code`` defines at least one ``async def`` function.

    The source is parsed with ``ast.parse`` and every node of the resulting
    tree, nested ones included, is inspected via ``ast.walk``.

    Args:
        code: Python source text to inspect.

    Returns:
        ``True`` if any ``ast.AsyncFunctionDef`` node is present; ``False``
        if there is none or if ``code`` cannot be parsed.
    """
    try:
        module = ast.parse(code)
    except (SyntaxError, ValueError):
        # Interpreters before 3.12 raise ValueError (not SyntaxError) for
        # source containing null bytes; treat that as unparsable as well.
        return False
    return any(isinstance(node, ast.AsyncFunctionDef) for node in ast.walk(module))


def cleanup_paths(paths: list[Path]) -> None:
for path in paths:
if path and path.exists():
Expand Down
12 changes: 9 additions & 3 deletions codeflash/code_utils/static_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,19 @@ def get_first_top_level_object_def_ast(

def get_first_top_level_function_or_method_ast(
    function_name: str, parents: list[FunctionParent], node: ast.AST
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
    """Find the sync or async definition named ``function_name`` under ``node``.

    With no ``parents`` the lookup is done directly in ``node``. When the
    first parent is a ``ClassDef``, that class is located in ``node`` first
    and the lookup is done inside it. Any other parent kind yields ``None``.

    Args:
        function_name: Name of the function or method to find.
        parents: Enclosing scopes of the target; only ``parents[0]`` is consulted.
        node: AST node to search in.

    Returns:
        The matching ``ast.FunctionDef`` if one is found, otherwise the
        matching ``ast.AsyncFunctionDef``, otherwise ``None``.
    """
    if not parents:
        search_root = node
    else:
        if parents[0].type != "ClassDef":
            return None
        class_node = get_first_top_level_object_def_ast(parents[0].name, ast.ClassDef, node)
        if not class_node:
            return None
        search_root = class_node

    # A plain ``def`` takes precedence over an ``async def`` of the same name.
    for def_type in (ast.FunctionDef, ast.AsyncFunctionDef):
        result = get_first_top_level_object_def_ast(function_name, def_type, search_root)
        if result is not None:
            return result
    return None


Expand Down
32 changes: 30 additions & 2 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
parents=list(reversed(ast_parents)),
starting_line=pos.start.line,
ending_line=pos.end.line,
is_async=bool(node.asynchronous),
)
)

Expand All @@ -103,6 +104,15 @@ def visit_FunctionDef(self, node: FunctionDef) -> None:
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
)

def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
    """Record ``node`` as an async optimization candidate when it qualifies.

    An async function qualifies when ``function_has_return_statement``
    accepts it and ``function_is_a_property`` rejects it. Qualifying
    functions are appended to ``self.functions`` with ``is_async=True``.
    """
    # Skip async functions without a return statement, and properties.
    if not function_has_return_statement(node) or function_is_a_property(node):
        return
    candidate = FunctionToOptimize(
        function_name=node.name,
        file_path=self.file_path,
        parents=self.ast_path[:],  # copy so later path changes don't leak in
        is_async=True,
    )
    self.functions.append(candidate)

def generic_visit(self, node: ast.AST) -> None:
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
self.ast_path.append(FunctionParent(node.name, node.__class__.__name__))
Expand All @@ -122,6 +132,7 @@ class FunctionToOptimize:
parents: A list of parent scopes, which could be classes or functions.
starting_line: The starting line number of the function in the file.
ending_line: The ending line number of the function in the file.
is_async: Whether this function is defined as async.

The qualified_name property provides the full name of the function, including
any parent class or function names. The qualified_name_with_modules_from_root
Expand All @@ -134,6 +145,7 @@ class FunctionToOptimize:
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
starting_line: Optional[int] = None
ending_line: Optional[int] = None
is_async: bool = False

@property
def top_level_parent_name(self) -> str:
Expand Down Expand Up @@ -402,11 +414,27 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
)
)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
    """Flag a matching module-level async function and note whether it takes arguments.

    Does nothing when a class name is being targeted or when the name of
    ``node`` differs from ``self.function_name``. Otherwise sets
    ``self.is_top_level`` and stores in ``self.function_has_args`` whether
    the signature declares any parameter of any kind.
    """
    if self.class_name is not None or node.name != self.function_name:
        return
    self.is_top_level = True
    signature = node.args
    # True as soon as any parameter category is non-empty.
    self.function_has_args = bool(
        signature.args
        or signature.kwonlyargs
        or signature.kwarg
        or signature.posonlyargs
        or signature.vararg
    )

def visit_ClassDef(self, node: ast.ClassDef) -> None:
# iterate over the class methods
if node.name == self.class_name:
for body_node in node.body:
if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name:
if (
isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
and body_node.name == self.function_name
):
self.is_top_level = True
if any(
isinstance(decorator, ast.Name) and decorator.id == "classmethod"
Expand All @@ -424,7 +452,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
# This way, if we don't have the class name, we can still find the static method
for body_node in node.body:
if (
isinstance(body_node, ast.FunctionDef)
isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
and body_node.name == self.function_name
and body_node.lineno in {self.line_no, self.line_no + 1}
and any(
Expand Down
8 changes: 1 addition & 7 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
diff_length,
file_name_from_test_module_name,
get_run_tmp_file,
has_any_async_functions,
module_name_from_file_path,
restore_conftest,
)
Expand Down Expand Up @@ -189,7 +188,7 @@ def __init__(
test_cfg: TestConfig,
function_to_optimize_source_code: str = "",
function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
function_to_optimize_ast: ast.FunctionDef | None = None,
function_to_optimize_ast: ast.FunctionDef | ast.AsyncFunctionDef | None = None,
aiservice_client: AiServiceClient | None = None,
function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
Expand Down Expand Up @@ -248,11 +247,6 @@ def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[P
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code

async_code = any(
has_any_async_functions(code_string.code) for code_string in code_context.read_writable_code.code_strings
)
if async_code:
return Failure("Codeflash does not support async functions in the code to optimize.")
# Random here means that we still attempt optimization with a fractional chance to see if
# last time we could not find an optimization, maybe this time we do.
# Random is before as a performance optimization, swapping the two 'and' statements has the same effect
Expand Down
6 changes: 3 additions & 3 deletions codeflash/verification/coverage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ def load_from_sqlite_database(

cov = Coverage(data_file=database_path, config_file=config_path, data_suffix=True, auto_data=True, branch=True)

if not database_path.stat().st_size or not database_path.exists():
if not database_path.exists() or not database_path.stat().st_size:
logger.debug(f"Coverage database {database_path} is empty or does not exist")
sentry_sdk.capture_message(f"Coverage database {database_path} is empty or does not exist")
return CoverageUtils.create_empty(source_code_path, function_name, code_context)
return CoverageData.create_empty(source_code_path, function_name, code_context)
cov.load()

reporter = JsonReporter(cov)
Expand All @@ -51,7 +51,7 @@ def load_from_sqlite_database(
reporter.report(morfs=[source_code_path.as_posix()], outfile=f)
except NoDataError:
sentry_sdk.capture_message(f"No coverage data found for {function_name} in {source_code_path}")
return CoverageUtils.create_empty(source_code_path, function_name, code_context)
return CoverageData.create_empty(source_code_path, function_name, code_context)
with temp_json_file.open() as f:
original_coverage_data = json.load(f)

Expand Down
Loading
Loading