Skip to content

Commit 593b7f2

Browse files
committed
another impl
1 parent f2379ce commit 593b7f2

File tree

3 files changed

+35
-6
lines changed

3 files changed

+35
-6
lines changed

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
test_framework: str,
5757
call_positions: list[CodePosition],
5858
mode: TestingMode = TestingMode.BEHAVIOR,
59+
is_async: bool = False,
5960
) -> None:
6061
self.mode: TestingMode = mode
6162
self.function_object = function
@@ -64,6 +65,7 @@ def __init__(
6465
self.module_path = module_path
6566
self.test_framework = test_framework
6667
self.call_positions = call_positions
68+
self.is_async = is_async
6769
if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
6870
self.class_name = function.top_level_parent_name
6971

@@ -328,6 +330,7 @@ def inject_profiling_into_existing_test(
328330
tests_project_root: Path,
329331
test_framework: str,
330332
mode: TestingMode = TestingMode.BEHAVIOR,
333+
is_async: bool = False,
331334
) -> tuple[bool, str | None]:
332335
with test_path.open(encoding="utf8") as f:
333336
test_code = f.read()
@@ -342,7 +345,9 @@ def inject_profiling_into_existing_test(
342345
import_visitor.visit(tree)
343346
func = import_visitor.imported_as
344347

345-
tree = InjectPerfOnly(func, test_module_path, test_framework, call_positions, mode=mode).visit(tree)
348+
tree = InjectPerfOnly(func, test_module_path, test_framework, call_positions, mode=mode, is_async=is_async).visit(
349+
tree
350+
)
346351
new_imports = [
347352
ast.Import(names=[ast.alias(name="time")]),
348353
ast.Import(names=[ast.alias(name="gc")]),
@@ -354,11 +359,11 @@ def inject_profiling_into_existing_test(
354359
)
355360
if test_framework == "unittest":
356361
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
357-
tree.body = [*new_imports, create_wrapper_function(mode), *tree.body]
362+
tree.body = [*new_imports, create_wrapper_function(mode, is_async), *tree.body]
358363
return True, isort.code(ast.unparse(tree), float_to_top=True)
359364

360365

361-
def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.FunctionDef:
366+
def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR, is_async: bool = False) -> ast.FunctionDef:
362367
lineno = 1
363368
wrapper_body: list[ast.stmt] = [
364369
ast.Assign(
@@ -536,7 +541,15 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
536541
),
537542
ast.Assign(
538543
targets=[ast.Name(id="return_value", ctx=ast.Store())],
539-
value=ast.Call(
544+
value=ast.Await(
545+
value=ast.Call(
546+
func=ast.Name(id="wrapped", ctx=ast.Load()),
547+
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
548+
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
549+
)
550+
)
551+
if is_async
552+
else ast.Call(
540553
func=ast.Name(id="wrapped", ctx=ast.Load()),
541554
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
542555
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
@@ -703,7 +716,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
703716
),
704717
ast.Return(value=ast.Name(id="return_value", ctx=ast.Load()), lineno=lineno + 19),
705718
]
706-
return ast.FunctionDef(
719+
func_def = ast.FunctionDef(
707720
name="codeflash_wrap",
708721
args=ast.arguments(
709722
args=[
@@ -729,3 +742,13 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
729742
decorator_list=[],
730743
returns=None,
731744
)
745+
if is_async:
746+
return ast.AsyncFunctionDef(
747+
name="codeflash_wrap",
748+
args=func_def.args,
749+
body=func_def.body,
750+
lineno=func_def.lineno,
751+
decorator_list=func_def.decorator_list,
752+
returns=func_def.returns,
753+
)
754+
return func_def

codeflash/discovery/functions_to_optimize.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
106106
# Check if the async function has a return statement and add it to the list
107107
if function_has_return_statement(node) and not function_is_a_property(node):
108108
self.functions.append(
109-
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
109+
FunctionToOptimize(
110+
function_name=node.name, file_path=self.file_path, parents=self.ast_path[:], is_async=True
111+
)
110112
)
111113

112114
def generic_visit(self, node: ast.AST) -> None:
@@ -128,6 +130,7 @@ class FunctionToOptimize:
128130
parents: A list of parent scopes, which could be classes or functions.
129131
starting_line: The starting line number of the function in the file.
130132
ending_line: The ending line number of the function in the file.
133+
is_async: Whether this function is defined as async.
131134
132135
The qualified_name property provides the full name of the function, including
133136
any parent class or function names. The qualified_name_with_modules_from_root
@@ -140,6 +143,7 @@ class FunctionToOptimize:
140143
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
141144
starting_line: Optional[int] = None
142145
ending_line: Optional[int] = None
146+
is_async: bool = False
143147

144148
@property
145149
def top_level_parent_name(self) -> str:

codeflash/optimization/function_optimizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
785785
function_to_optimize=self.function_to_optimize,
786786
tests_project_root=self.test_cfg.tests_project_rootdir,
787787
test_framework=self.args.test_framework,
788+
is_async=self.function_to_optimize.is_async,
788789
)
789790
if not success:
790791
continue
@@ -795,6 +796,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
795796
function_to_optimize=self.function_to_optimize,
796797
tests_project_root=self.test_cfg.tests_project_rootdir,
797798
test_framework=self.args.test_framework,
799+
is_async=self.function_to_optimize.is_async,
798800
)
799801
if not success:
800802
continue

0 commit comments

Comments
 (0)