Skip to content

Commit 3e106cb

Browse files
committed
linting / cleanup
1 parent 14af1a8 commit 3e106cb

File tree

2 files changed

+36
-37
lines changed

2 files changed

+36
-37
lines changed

codeflash/code_utils/codeflash_wrap_decorator.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def extract_test_context_from_frame() -> tuple[str, str | None, str]:
3232
del frame
3333

3434

35-
def codeflash_behavior_async(func):
35+
def codeflash_behavior_async(func: F) -> F:
3636
"""Async decorator for behavior analysis - collects timing data and function inputs/outputs."""
3737
function_name = func.__name__
3838
line_id = f"{func.__name__}_{func.__code__.co_firstlineno}"
@@ -108,7 +108,7 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
108108
return async_wrapper
109109

110110

111-
def codeflash_performance_async(func):
111+
def codeflash_performance_async(func: F) -> F:
112112
"""Async decorator for performance analysis - lightweight timing measurements only."""
113113
function_name = func.__name__
114114
line_id = f"{func.__name__}_{func.__code__.co_firstlineno}"
@@ -161,8 +161,3 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
161161

162162
return async_wrapper
163163

164-
165-
# Convenience function to get the appropriate decorator based on mode
166-
def codeflash_trace_async(func):
167-
"""Default async decorator - uses behavior analysis mode."""
168-
return codeflash_behavior_async(func)

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -323,25 +323,23 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
323323

324324

325325
def instrument_source_module_with_async_decorators(
326-
source_path: Path,
327-
function_to_optimize: FunctionToOptimize,
328-
mode: TestingMode = TestingMode.BEHAVIOR,
326+
source_path: Path, function_to_optimize: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
329327
) -> tuple[bool, str | None]:
330328
if not function_to_optimize.is_async:
331329
return False, None
332-
330+
333331
try:
334332
with source_path.open(encoding="utf8") as f:
335333
source_code = f.read()
336-
334+
337335
modified_code, decorator_added = add_async_decorator_to_function(source_code, function_to_optimize, mode)
338-
336+
339337
if decorator_added:
340338
return True, modified_code
341-
else:
342-
return False, None
343-
344-
except Exception as e:
339+
340+
except Exception:
341+
return False, None
342+
else:
345343
return False, None
346344

347345

@@ -361,7 +359,7 @@ def inject_profiling_into_existing_test(
361359
except SyntaxError:
362360
logger.exception(f"Syntax error in code in file - {test_path}")
363361
return False, None
364-
362+
365363
test_module_path = module_name_from_file_path(test_path, tests_project_root)
366364
import_visitor = FunctionImportedAsVisitor(function_to_optimize)
367365
import_visitor.visit(tree)
@@ -779,7 +777,7 @@ def __init__(self, function: FunctionToOptimize, mode: TestingMode = TestingMode
779777
self.qualified_name_parts = function.qualified_name.split(".")
780778
self.context_stack = []
781779
self.added_decorator = False
782-
780+
783781
# Choose decorator based on mode
784782
self.decorator_name = (
785783
"codeflash_behavior_async" if mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
@@ -798,12 +796,9 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
798796
# Track when we enter a function
799797
self.context_stack.append(node.name.value)
800798

801-
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
799+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
802800
# Check if this is an async function and matches our target
803-
if (
804-
original_node.asynchronous is not None
805-
and self.context_stack == self.qualified_name_parts
806-
):
801+
if original_node.asynchronous is not None and self.context_stack == self.qualified_name_parts:
807802
# Check if the decorator is already present
808803
has_decorator = any(
809804
self._is_target_decorator(decorator.decorator) for decorator in original_node.decorators
@@ -825,9 +820,17 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
825820
def _is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Call) -> bool:
826821
"""Check if a decorator matches our target decorator name."""
827822
if isinstance(decorator_node, cst.Name):
828-
return decorator_node.value in {"codeflash_trace_async", "codeflash_behavior_async", "codeflash_performance_async"}
823+
return decorator_node.value in {
824+
"codeflash_trace_async",
825+
"codeflash_behavior_async",
826+
"codeflash_performance_async",
827+
}
829828
if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
830-
return decorator_node.func.value in {"codeflash_trace_async", "codeflash_behavior_async", "codeflash_performance_async"}
829+
return decorator_node.func.value in {
830+
"codeflash_trace_async",
831+
"codeflash_behavior_async",
832+
"codeflash_performance_async",
833+
}
831834
return False
832835

833836

@@ -847,15 +850,14 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
847850
and node.module.value.value.value == "codeflash"
848851
and node.module.value.attr.value == "code_utils"
849852
and node.module.attr.value == "codeflash_wrap_decorator"
853+
and not isinstance(node.names, cst.ImportStar)
850854
):
851-
# Handle both ImportAlias and ImportStar
852-
if not isinstance(node.names, cst.ImportStar):
853-
decorator_name = (
854-
"codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
855-
)
856-
for import_alias in node.names:
857-
if import_alias.name.value == decorator_name:
858-
self.has_import = True
855+
decorator_name = (
856+
"codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
857+
)
858+
for import_alias in node.names:
859+
if import_alias.name.value == decorator_name:
860+
self.has_import = True
859861

860862
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
861863
# If the import is already there, don't add it again
@@ -866,15 +868,17 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
866868
decorator_name = (
867869
"codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
868870
)
869-
871+
870872
# Parse the import statement into a CST node
871873
import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}")
872874

873875
# Add the import to the module's body
874876
return updated_node.with_changes(body=[import_node, *list(updated_node.body)])
875877

876878

877-
def add_async_decorator_to_function(source_code: str, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR) -> tuple[str, bool]:
879+
def add_async_decorator_to_function(
880+
source_code: str, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
881+
) -> tuple[str, bool]:
878882
"""Add async decorator to an async function definition.
879883
880884
Args:
@@ -890,7 +894,7 @@ def add_async_decorator_to_function(source_code: str, function: FunctionToOptimi
890894
"""
891895
if not function.is_async:
892896
return source_code, False
893-
897+
894898
try:
895899
module = cst.parse_module(source_code)
896900

0 commit comments

Comments
 (0)