@@ -323,25 +323,23 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
323323
324324
325325def instrument_source_module_with_async_decorators (
326- source_path : Path ,
327- function_to_optimize : FunctionToOptimize ,
328- mode : TestingMode = TestingMode .BEHAVIOR ,
326+ source_path : Path , function_to_optimize : FunctionToOptimize , mode : TestingMode = TestingMode .BEHAVIOR
329327) -> tuple [bool , str | None ]:
330328 if not function_to_optimize .is_async :
331329 return False , None
332-
330+
333331 try :
334332 with source_path .open (encoding = "utf8" ) as f :
335333 source_code = f .read ()
336-
334+
337335 modified_code , decorator_added = add_async_decorator_to_function (source_code , function_to_optimize , mode )
338-
336+
339337 if decorator_added :
340338 return True , modified_code
341- else :
342- return False , None
343-
344- except Exception as e :
339+
340+ except Exception :
341+ return False , None
342+ else :
345343 return False , None
346344
347345
@@ -361,7 +359,7 @@ def inject_profiling_into_existing_test(
361359 except SyntaxError :
362360 logger .exception (f"Syntax error in code in file - { test_path } " )
363361 return False , None
364-
362+
365363 test_module_path = module_name_from_file_path (test_path , tests_project_root )
366364 import_visitor = FunctionImportedAsVisitor (function_to_optimize )
367365 import_visitor .visit (tree )
@@ -779,7 +777,7 @@ def __init__(self, function: FunctionToOptimize, mode: TestingMode = TestingMode
779777 self .qualified_name_parts = function .qualified_name .split ("." )
780778 self .context_stack = []
781779 self .added_decorator = False
782-
780+
783781 # Choose decorator based on mode
784782 self .decorator_name = (
785783 "codeflash_behavior_async" if mode == TestingMode .BEHAVIOR else "codeflash_performance_async"
@@ -798,12 +796,9 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
798796 # Track when we enter a function
799797 self .context_stack .append (node .name .value )
800798
801- def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef : # noqa: ARG002
799+ def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef :
802800 # Check if this is an async function and matches our target
803- if (
804- original_node .asynchronous is not None
805- and self .context_stack == self .qualified_name_parts
806- ):
801+ if original_node .asynchronous is not None and self .context_stack == self .qualified_name_parts :
807802 # Check if the decorator is already present
808803 has_decorator = any (
809804 self ._is_target_decorator (decorator .decorator ) for decorator in original_node .decorators
@@ -825,9 +820,17 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
825820 def _is_target_decorator (self , decorator_node : cst .Name | cst .Attribute | cst .Call ) -> bool :
826821 """Check if a decorator matches our target decorator name."""
827822 if isinstance (decorator_node , cst .Name ):
828- return decorator_node .value in {"codeflash_trace_async" , "codeflash_behavior_async" , "codeflash_performance_async" }
823+ return decorator_node .value in {
824+ "codeflash_trace_async" ,
825+ "codeflash_behavior_async" ,
826+ "codeflash_performance_async" ,
827+ }
829828 if isinstance (decorator_node , cst .Call ) and isinstance (decorator_node .func , cst .Name ):
830- return decorator_node .func .value in {"codeflash_trace_async" , "codeflash_behavior_async" , "codeflash_performance_async" }
829+ return decorator_node .func .value in {
830+ "codeflash_trace_async" ,
831+ "codeflash_behavior_async" ,
832+ "codeflash_performance_async" ,
833+ }
831834 return False
832835
833836
@@ -847,15 +850,14 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
847850 and node .module .value .value .value == "codeflash"
848851 and node .module .value .attr .value == "code_utils"
849852 and node .module .attr .value == "codeflash_wrap_decorator"
853+ and not isinstance (node .names , cst .ImportStar )
850854 ):
851- # Handle both ImportAlias and ImportStar
852- if not isinstance (node .names , cst .ImportStar ):
853- decorator_name = (
854- "codeflash_behavior_async" if self .mode == TestingMode .BEHAVIOR else "codeflash_performance_async"
855- )
856- for import_alias in node .names :
857- if import_alias .name .value == decorator_name :
858- self .has_import = True
855+ decorator_name = (
856+ "codeflash_behavior_async" if self .mode == TestingMode .BEHAVIOR else "codeflash_performance_async"
857+ )
858+ for import_alias in node .names :
859+ if import_alias .name .value == decorator_name :
860+ self .has_import = True
859861
860862 def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ) -> cst .Module : # noqa: ARG002
861863 # If the import is already there, don't add it again
@@ -866,15 +868,17 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
866868 decorator_name = (
867869 "codeflash_behavior_async" if self .mode == TestingMode .BEHAVIOR else "codeflash_performance_async"
868870 )
869-
871+
870872 # Parse the import statement into a CST node
871873 import_node = cst .parse_statement (f"from codeflash.code_utils.codeflash_wrap_decorator import { decorator_name } " )
872874
873875 # Add the import to the module's body
874876 return updated_node .with_changes (body = [import_node , * list (updated_node .body )])
875877
876878
877- def add_async_decorator_to_function (source_code : str , function : FunctionToOptimize , mode : TestingMode = TestingMode .BEHAVIOR ) -> tuple [str , bool ]:
879+ def add_async_decorator_to_function (
880+ source_code : str , function : FunctionToOptimize , mode : TestingMode = TestingMode .BEHAVIOR
881+ ) -> tuple [str , bool ]:
878882 """Add async decorator to an async function definition.
879883
880884 Args:
@@ -890,7 +894,7 @@ def add_async_decorator_to_function(source_code: str, function: FunctionToOptimi
890894 """
891895 if not function .is_async :
892896 return source_code , False
893-
897+
894898 try :
895899 module = cst .parse_module (source_code )
896900
0 commit comments