Skip to content

Commit 14af1a8

Browse files
committed
implementation
1 parent 4c6910f commit 14af1a8

File tree

6 files changed

+1406
-631
lines changed

6 files changed

+1406
-631
lines changed

codeflash/code_utils/codeflash_wrap_decorator.py

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55
import os
66
import pickle
77
import time
8-
from collections.abc import Awaitable
98
from functools import wraps
109
from pathlib import Path
1110
from typing import Any, Callable, TypeVar
1211

1312
from codeflash.code_utils.code_utils import get_run_tmp_file
1413

15-
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
14+
F = TypeVar("F", bound=Callable[..., Any])
1615

1716

1817
def extract_test_context_from_frame() -> tuple[str, str | None, str]:
@@ -33,9 +32,10 @@ def extract_test_context_from_frame() -> tuple[str, str | None, str]:
3332
del frame
3433

3534

36-
def codeflash_behavior_async(func: F) -> F:
35+
def codeflash_behavior_async(func):
36+
"""Async decorator for behavior analysis - collects timing data and function inputs/outputs."""
3737
function_name = func.__name__
38-
line_id = f"{function_name}_{func.__code__.co_firstlineno}"
38+
line_id = f"{func.__name__}_{func.__code__.co_firstlineno}"
3939
loop_index = int(os.environ.get("CODEFLASH_LOOP_INDEX", "1"))
4040

4141
@wraps(func)
@@ -106,3 +106,63 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
106106
return return_value
107107

108108
return async_wrapper
109+
110+
111+
def codeflash_performance_async(func):
112+
"""Async decorator for performance analysis - lightweight timing measurements only."""
113+
function_name = func.__name__
114+
line_id = f"{func.__name__}_{func.__code__.co_firstlineno}"
115+
loop_index = int(os.environ.get("CODEFLASH_LOOP_INDEX", "1"))
116+
117+
@wraps(func)
118+
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
119+
test_module_name, test_class_name, test_name = extract_test_context_from_frame()
120+
121+
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
122+
123+
if not hasattr(async_wrapper, "index"):
124+
async_wrapper.index = {}
125+
if test_id in async_wrapper.index:
126+
async_wrapper.index[test_id] += 1
127+
else:
128+
async_wrapper.index[test_id] = 0
129+
130+
codeflash_test_index = async_wrapper.index[test_id]
131+
invocation_id = f"{line_id}_{codeflash_test_index}"
132+
test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
133+
134+
print(f"!$######{test_stdout_tag}######$!")
135+
136+
exception = None
137+
gc.disable()
138+
try:
139+
counter = time.perf_counter_ns()
140+
ret = func(*args, **kwargs)
141+
142+
if inspect.isawaitable(ret):
143+
counter = time.perf_counter_ns()
144+
return_value = await ret
145+
else:
146+
return_value = ret
147+
148+
codeflash_duration = time.perf_counter_ns() - counter
149+
except Exception as e:
150+
codeflash_duration = time.perf_counter_ns() - counter
151+
exception = e
152+
finally:
153+
gc.enable()
154+
155+
# For performance mode, include timing in the output tag like sync functions do
156+
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
157+
158+
if exception:
159+
raise exception
160+
return return_value
161+
162+
return async_wrapper
163+
164+
165+
# Convenience function to get the appropriate decorator based on mode
166+
def codeflash_trace_async(func):
167+
"""Default async decorator - uses behavior analysis mode."""
168+
return codeflash_behavior_async(func)

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 66 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -322,13 +322,37 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
322322
)
323323

324324

325+
def instrument_source_module_with_async_decorators(
326+
source_path: Path,
327+
function_to_optimize: FunctionToOptimize,
328+
mode: TestingMode = TestingMode.BEHAVIOR,
329+
) -> tuple[bool, str | None]:
330+
if not function_to_optimize.is_async:
331+
return False, None
332+
333+
try:
334+
with source_path.open(encoding="utf8") as f:
335+
source_code = f.read()
336+
337+
modified_code, decorator_added = add_async_decorator_to_function(source_code, function_to_optimize, mode)
338+
339+
if decorator_added:
340+
return True, modified_code
341+
else:
342+
return False, None
343+
344+
except Exception as e:
345+
return False, None
346+
347+
325348
def inject_profiling_into_existing_test(
326349
test_path: Path,
327350
call_positions: list[CodePosition],
328351
function_to_optimize: FunctionToOptimize,
329352
tests_project_root: Path,
330353
test_framework: str,
331354
mode: TestingMode = TestingMode.BEHAVIOR,
355+
source_module_path: Path | None = None,
332356
) -> tuple[bool, str | None]:
333357
with test_path.open(encoding="utf8") as f:
334358
test_code = f.read()
@@ -343,11 +367,10 @@ def inject_profiling_into_existing_test(
343367
import_visitor.visit(tree)
344368
func = import_visitor.imported_as
345369

346-
if func.is_async:
347-
modified_code, decorator_added = add_async_decorator_to_function(test_code, func)
348-
if decorator_added:
349-
logger.debug(f"Applied @codeflash_trace_async decorator to async function {func.qualified_name}")
350-
return True, modified_code
370+
if func.is_async and source_module_path and source_module_path.exists():
371+
source_success, instrumented_source = instrument_source_module_with_async_decorators(
372+
source_module_path, func, mode
373+
)
351374

352375
tree = InjectPerfOnly(func, test_module_path, test_framework, call_positions, mode=mode).visit(tree)
353376
new_imports = [
@@ -739,21 +762,28 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
739762

740763

741764
class AsyncDecoratorAdder(cst.CSTTransformer):
742-
"""Transformer that adds @codeflash_trace_async decorator to async function definitions."""
765+
"""Transformer that adds async decorator to async function definitions."""
743766

744-
def __init__(self, function: FunctionToOptimize) -> None:
767+
def __init__(self, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
745768
"""Initialize the transformer.
746769
747770
Args:
748771
----
749772
function: The FunctionToOptimize object representing the target async function.
773+
mode: The testing mode to determine which decorator to apply.
750774
751775
"""
752776
super().__init__()
753777
self.function = function
778+
self.mode = mode
754779
self.qualified_name_parts = function.qualified_name.split(".")
755780
self.context_stack = []
756781
self.added_decorator = False
782+
783+
# Choose decorator based on mode
784+
self.decorator_name = (
785+
"codeflash_behavior_async" if mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
786+
)
757787

758788
def visit_ClassDef(self, node: cst.ClassDef) -> None:
759789
# Track when we enter a class
@@ -781,7 +811,7 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
781811

782812
# Only add the decorator if it's not already there
783813
if not has_decorator:
784-
new_decorator = cst.Decorator(decorator=cst.Name(value="codeflash_trace_async"))
814+
new_decorator = cst.Decorator(decorator=cst.Name(value=self.decorator_name))
785815

786816
# Add our new decorator to the existing decorators
787817
updated_decorators = [new_decorator, *list(updated_node.decorators)]
@@ -795,16 +825,17 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
795825
def _is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Call) -> bool:
796826
"""Check if a decorator matches our target decorator name."""
797827
if isinstance(decorator_node, cst.Name):
798-
return decorator_node.value == "codeflash_trace_async"
828+
return decorator_node.value in {"codeflash_trace_async", "codeflash_behavior_async", "codeflash_performance_async"}
799829
if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
800-
return decorator_node.func.value == "codeflash_trace_async"
830+
return decorator_node.func.value in {"codeflash_trace_async", "codeflash_behavior_async", "codeflash_performance_async"}
801831
return False
802832

803833

804834
class AsyncDecoratorImportAdder(cst.CSTTransformer):
805-
"""Transformer that adds the import for codeflash_trace_async."""
835+
"""Transformer that adds the import for async decorators."""
806836

807-
def __init__(self) -> None:
837+
def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
838+
self.mode = mode
808839
self.has_import = False
809840

810841
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
@@ -819,48 +850,65 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
819850
):
820851
# Handle both ImportAlias and ImportStar
821852
if not isinstance(node.names, cst.ImportStar):
853+
decorator_name = (
854+
"codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
855+
)
822856
for import_alias in node.names:
823-
if import_alias.name.value == "codeflash_trace_async":
857+
if import_alias.name.value == decorator_name:
824858
self.has_import = True
825859

826860
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
827861
# If the import is already there, don't add it again
828862
if self.has_import:
829863
return updated_node
830864

865+
# Choose import based on mode
866+
decorator_name = (
867+
"codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
868+
)
869+
831870
# Parse the import statement into a CST node
832-
import_node = cst.parse_statement("from codeflash.code_utils.codeflash_wrap_decorator import codeflash_trace_async")
871+
import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}")
833872

834873
# Add the import to the module's body
835874
return updated_node.with_changes(body=[import_node, *list(updated_node.body)])
836875

837876

838-
def add_async_decorator_to_function(source_code: str, function: FunctionToOptimize) -> tuple[str, bool]:
839-
"""Add @codeflash_trace_async decorator to an async function definition.
877+
def add_async_decorator_to_function(source_code: str, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR) -> tuple[str, bool]:
878+
"""Add async decorator to an async function definition.
840879
841880
Args:
842881
----
843882
source_code: The source code to modify.
844883
function: The FunctionToOptimize object representing the target async function.
884+
mode: The testing mode to determine which decorator to apply.
845885
846886
Returns:
847887
-------
848888
Tuple of (modified_source_code, was_decorator_added).
849889
850890
"""
891+
if not function.is_async:
892+
return source_code, False
893+
851894
try:
852895
module = cst.parse_module(source_code)
853896

854897
# Add the decorator to the function
855-
decorator_transformer = AsyncDecoratorAdder(function)
898+
decorator_transformer = AsyncDecoratorAdder(function, mode)
856899
module = module.visit(decorator_transformer)
857900

858901
# Add the import if decorator was added
859902
if decorator_transformer.added_decorator:
860-
import_transformer = AsyncDecoratorImportAdder()
903+
import_transformer = AsyncDecoratorImportAdder(mode)
861904
module = module.visit(import_transformer)
862905

863906
return isort.code(module.code, float_to_top=True), decorator_transformer.added_decorator
864907
except Exception as e:
865908
logger.exception(f"Error adding async decorator to function {function.qualified_name}: {e}")
866909
return source_code, False
910+
911+
912+
def create_instrumented_source_module_path(source_path: Path, temp_dir: Path) -> Path:
913+
instrumented_filename = f"instrumented_{source_path.name}"
914+
return temp_dir / instrumented_filename

codeflash/discovery/functions_to_optimize.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,11 @@ def __str__(self) -> str:
159159

160160
@property
161161
def qualified_name(self) -> str:
162-
return self.function_name if self.parents == [] else f"{self.parents[0].name}.{self.function_name}"
162+
if not self.parents:
163+
return self.function_name
164+
# Join all parent names with dots to handle nested classes properly
165+
parent_path = ".".join(parent.name for parent in self.parents)
166+
return f"{parent_path}.{self.function_name}"
163167

164168
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
165169
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ dependencies = [
4343
"platformdirs>=4.3.7",
4444
"pygls>=1.3.1",
4545
"codeflash-benchmark",
46+
"pytest-asyncio>=1.1.0",
4647
]
4748

4849
[project.urls]

0 commit comments

Comments
 (0)