Skip to content

Commit 759fdb1

Browse files
committed
Complete async test instrumentation and utilities implementation
- Added comprehensive async test instrumentation (AsyncCallInstrumenter class) - Implemented async decorator functions (add_async_decorator_to_function, instrument_source_module_with_async_decorators) - Added async wrapper decorators (codeflash_behavior_async, codeflash_performance_async) - Updated edit_generated_tests.py to handle AsyncFunctionDef nodes in test parsing - Updated coverage_utils.py to include async functions in coverage analysis
1 parent 3e96474 commit 759fdb1

File tree

3 files changed

+388
-8
lines changed

3 files changed

+388
-8
lines changed

codeflash/code_utils/coverage_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ def extract_dependent_function(main_function: str, code_context: CodeOptimizatio
1414
"""Extract the single dependent function from the code context excluding the main function."""
1515
ast_tree = ast.parse(code_context.testgen_context_code)
1616

17-
dependent_functions = {node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)}
17+
dependent_functions = {
18+
node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
19+
}
1820

1921
if main_function in dependent_functions:
2022
dependent_functions.discard(main_function)

codeflash/code_utils/edit_generated_tests.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@ def __init__(
3232

3333
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
3434
self.context_stack.append(node.name)
35-
for inner_node in ast.walk(node):
35+
for inner_node in node.body:
3636
if isinstance(inner_node, ast.FunctionDef):
3737
self.visit_FunctionDef(inner_node)
38+
elif isinstance(inner_node, ast.AsyncFunctionDef):
39+
self.visit_AsyncFunctionDef(inner_node)
3840
self.context_stack.pop()
3941
return node
4042

@@ -50,6 +52,14 @@ def get_comment(self, match_key: str) -> str:
5052
return f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
5153

5254
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
55+
self._process_function_def_common(node)
56+
return node
57+
58+
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
59+
self._process_function_def_common(node)
60+
return node
61+
62+
def _process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
5363
self.context_stack.append(node.name)
5464
i = len(node.body) - 1
5565
test_qualified_name = ".".join(self.context_stack)
@@ -60,8 +70,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
6070
j = len(line_node.body) - 1
6171
while j >= 0:
6272
compound_line_node: ast.stmt = line_node.body[j]
63-
internal_node: ast.AST
64-
for internal_node in ast.walk(compound_line_node):
73+
nodes_to_check = [compound_line_node]
74+
nodes_to_check.extend(getattr(compound_line_node, "body", []))
75+
for internal_node in nodes_to_check:
6576
if isinstance(internal_node, (ast.stmt, ast.Assign)):
6677
inv_id = str(i) + "_" + str(j)
6778
match_key = key + "#" + inv_id
@@ -75,7 +86,6 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
7586
self.results[line_node.lineno] = self.get_comment(match_key)
7687
i -= 1
7788
self.context_stack.pop()
78-
return node
7989

8090

8191
def get_fn_call_linenos(
@@ -201,7 +211,7 @@ def remove_functions_from_generated_tests(
201211
for generated_test in generated_tests.generated_tests:
202212
for test_function in test_functions_to_remove:
203213
function_pattern = re.compile(
204-
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
214+
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?(async\s+)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\n(async\s+)?def\s|$)",
205215
re.DOTALL,
206216
)
207217

0 commit comments

Comments
 (0)