@@ -322,13 +322,37 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
322322 )
323323
324324
325+ def instrument_source_module_with_async_decorators (
326+ source_path : Path ,
327+ function_to_optimize : FunctionToOptimize ,
328+ mode : TestingMode = TestingMode .BEHAVIOR ,
329+ ) -> tuple [bool , str | None ]:
330+ if not function_to_optimize .is_async :
331+ return False , None
332+
333+ try :
334+ with source_path .open (encoding = "utf8" ) as f :
335+ source_code = f .read ()
336+
337+ modified_code , decorator_added = add_async_decorator_to_function (source_code , function_to_optimize , mode )
338+
339+ if decorator_added :
340+ return True , modified_code
341+ else :
342+ return False , None
343+
344+ except Exception as e :
345+ return False , None
346+
347+
325348def inject_profiling_into_existing_test (
326349 test_path : Path ,
327350 call_positions : list [CodePosition ],
328351 function_to_optimize : FunctionToOptimize ,
329352 tests_project_root : Path ,
330353 test_framework : str ,
331354 mode : TestingMode = TestingMode .BEHAVIOR ,
355+ source_module_path : Path | None = None ,
332356) -> tuple [bool , str | None ]:
333357 with test_path .open (encoding = "utf8" ) as f :
334358 test_code = f .read ()
@@ -343,11 +367,10 @@ def inject_profiling_into_existing_test(
343367 import_visitor .visit (tree )
344368 func = import_visitor .imported_as
345369
346- if func .is_async :
347- modified_code , decorator_added = add_async_decorator_to_function (test_code , func )
348- if decorator_added :
349- logger .debug (f"Applied @codeflash_trace_async decorator to async function { func .qualified_name } " )
350- return True , modified_code
370+ if func .is_async and source_module_path and source_module_path .exists ():
371+ source_success , instrumented_source = instrument_source_module_with_async_decorators (
372+ source_module_path , func , mode
373+ )
351374
352375 tree = InjectPerfOnly (func , test_module_path , test_framework , call_positions , mode = mode ).visit (tree )
353376 new_imports = [
@@ -739,21 +762,28 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
739762
740763
741764class AsyncDecoratorAdder (cst .CSTTransformer ):
742- """Transformer that adds @codeflash_trace_async decorator to async function definitions."""
765+ """Transformer that adds async decorator to async function definitions."""
743766
744- def __init__ (self , function : FunctionToOptimize ) -> None :
767+ def __init__ (self , function : FunctionToOptimize , mode : TestingMode = TestingMode . BEHAVIOR ) -> None :
745768 """Initialize the transformer.
746769
747770 Args:
748771 ----
749772 function: The FunctionToOptimize object representing the target async function.
773+ mode: The testing mode to determine which decorator to apply.
750774
751775 """
752776 super ().__init__ ()
753777 self .function = function
778+ self .mode = mode
754779 self .qualified_name_parts = function .qualified_name .split ("." )
755780 self .context_stack = []
756781 self .added_decorator = False
782+
783+ # Choose decorator based on mode
784+ self .decorator_name = (
785+ "codeflash_behavior_async" if mode == TestingMode .BEHAVIOR else "codeflash_performance_async"
786+ )
757787
758788 def visit_ClassDef (self , node : cst .ClassDef ) -> None :
759789 # Track when we enter a class
@@ -781,7 +811,7 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
781811
782812 # Only add the decorator if it's not already there
783813 if not has_decorator :
784- new_decorator = cst .Decorator (decorator = cst .Name (value = "codeflash_trace_async" ))
814+ new_decorator = cst .Decorator (decorator = cst .Name (value = self . decorator_name ))
785815
786816 # Add our new decorator to the existing decorators
787817 updated_decorators = [new_decorator , * list (updated_node .decorators )]
@@ -795,16 +825,17 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
795825 def _is_target_decorator (self , decorator_node : cst .Name | cst .Attribute | cst .Call ) -> bool :
796826 """Check if a decorator matches our target decorator name."""
797827 if isinstance (decorator_node , cst .Name ):
798- return decorator_node .value == "codeflash_trace_async"
828+ return decorator_node .value in { "codeflash_trace_async" , "codeflash_behavior_async" , "codeflash_performance_async" }
799829 if isinstance (decorator_node , cst .Call ) and isinstance (decorator_node .func , cst .Name ):
800- return decorator_node .func .value == "codeflash_trace_async"
830+ return decorator_node .func .value in { "codeflash_trace_async" , "codeflash_behavior_async" , "codeflash_performance_async" }
801831 return False
802832
803833
804834class AsyncDecoratorImportAdder (cst .CSTTransformer ):
805- """Transformer that adds the import for codeflash_trace_async ."""
835+ """Transformer that adds the import for async decorators ."""
806836
807- def __init__ (self ) -> None :
837+ def __init__ (self , mode : TestingMode = TestingMode .BEHAVIOR ) -> None :
838+ self .mode = mode
808839 self .has_import = False
809840
810841 def visit_ImportFrom (self , node : cst .ImportFrom ) -> None :
@@ -819,48 +850,65 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
819850 ):
820851 # Handle both ImportAlias and ImportStar
821852 if not isinstance (node .names , cst .ImportStar ):
853+ decorator_name = (
854+ "codeflash_behavior_async" if self .mode == TestingMode .BEHAVIOR else "codeflash_performance_async"
855+ )
822856 for import_alias in node .names :
823- if import_alias .name .value == "codeflash_trace_async" :
857+ if import_alias .name .value == decorator_name :
824858 self .has_import = True
825859
826860 def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ) -> cst .Module : # noqa: ARG002
827861 # If the import is already there, don't add it again
828862 if self .has_import :
829863 return updated_node
830864
865+ # Choose import based on mode
866+ decorator_name = (
867+ "codeflash_behavior_async" if self .mode == TestingMode .BEHAVIOR else "codeflash_performance_async"
868+ )
869+
831870 # Parse the import statement into a CST node
832- import_node = cst .parse_statement ("from codeflash.code_utils.codeflash_wrap_decorator import codeflash_trace_async " )
871+ import_node = cst .parse_statement (f "from codeflash.code_utils.codeflash_wrap_decorator import { decorator_name } " )
833872
834873 # Add the import to the module's body
835874 return updated_node .with_changes (body = [import_node , * list (updated_node .body )])
836875
837876
838- def add_async_decorator_to_function (source_code : str , function : FunctionToOptimize ) -> tuple [str , bool ]:
839- """Add @codeflash_trace_async decorator to an async function definition.
877+ def add_async_decorator_to_function (source_code : str , function : FunctionToOptimize , mode : TestingMode = TestingMode . BEHAVIOR ) -> tuple [str , bool ]:
878+ """Add async decorator to an async function definition.
840879
841880 Args:
842881 ----
843882 source_code: The source code to modify.
844883 function: The FunctionToOptimize object representing the target async function.
884+ mode: The testing mode to determine which decorator to apply.
845885
846886 Returns:
847887 -------
848888 Tuple of (modified_source_code, was_decorator_added).
849889
850890 """
891+ if not function .is_async :
892+ return source_code , False
893+
851894 try :
852895 module = cst .parse_module (source_code )
853896
854897 # Add the decorator to the function
855- decorator_transformer = AsyncDecoratorAdder (function )
898+ decorator_transformer = AsyncDecoratorAdder (function , mode )
856899 module = module .visit (decorator_transformer )
857900
858901 # Add the import if decorator was added
859902 if decorator_transformer .added_decorator :
860- import_transformer = AsyncDecoratorImportAdder ()
903+ import_transformer = AsyncDecoratorImportAdder (mode )
861904 module = module .visit (import_transformer )
862905
863906 return isort .code (module .code , float_to_top = True ), decorator_transformer .added_decorator
864907 except Exception as e :
865908 logger .exception (f"Error adding async decorator to function { function .qualified_name } : { e } " )
866909 return source_code , False
910+
911+
912+ def create_instrumented_source_module_path (source_path : Path , temp_dir : Path ) -> Path :
913+ instrumented_filename = f"instrumented_{ source_path .name } "
914+ return temp_dir / instrumented_filename
0 commit comments