Skip to content

Commit 862415e

Browse files
committed
add support for async tests in edit_generated_tests
1 parent 69fa701 commit 862415e

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

codeflash/code_utils/edit_generated_tests.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
3434
self.context_stack.append(node.name)
3535
for inner_node in ast.walk(node):
36-
if isinstance(inner_node, ast.FunctionDef):
36+
if isinstance(inner_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
3737
self.visit_FunctionDef(inner_node)
3838
self.context_stack.pop()
3939
return node
@@ -50,6 +50,12 @@ def get_comment(self, match_key: str) -> str:
5050
return f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
5151

5252
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
53+
return self._process_function_def(node)
54+
55+
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
56+
return self._process_function_def(node)
57+
58+
def _process_function_def(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.FunctionDef | ast.AsyncFunctionDef:
5359
self.context_stack.append(node.name)
5460
i = len(node.body) - 1
5561
test_qualified_name = ".".join(self.context_stack)

tests/test_async_run_and_parse_tests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ async def test_async_sort():
5454

5555
assert source_success
5656
assert instrumented_source is not None
57-
assert '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_behavior_async\n\n\n@codeflash_behavior_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n # Add some async delay to simulate async work\n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' in instrumented_source
57+
assert '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_behavior_async\n\n\n@codeflash_behavior_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' in instrumented_source
5858

5959
# Write the instrumented source back
6060
fto_path.write_text(instrumented_source, "utf-8")
@@ -287,7 +287,7 @@ async def test_async_perf():
287287

288288
assert source_success
289289
assert instrumented_source is not None
290-
assert '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_performance_async\n\n\n@codeflash_performance_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n # Add some async delay to simulate async work\n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' == instrumented_source
290+
assert '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_performance_async\n\n\n@codeflash_performance_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' == instrumented_source
291291

292292
fto_path.write_text(instrumented_source, "utf-8")
293293

0 commit comments

Comments
 (0)