@@ -56,6 +56,7 @@ def __init__(
5656 test_framework : str ,
5757 call_positions : list [CodePosition ],
5858 mode : TestingMode = TestingMode .BEHAVIOR ,
59+ is_async : bool = False ,
5960 ) -> None :
6061 self .mode : TestingMode = mode
6162 self .function_object = function
@@ -64,6 +65,7 @@ def __init__(
6465 self .module_path = module_path
6566 self .test_framework = test_framework
6667 self .call_positions = call_positions
68+ self .is_async = is_async
6769 if len (function .parents ) == 1 and function .parents [0 ].type == "ClassDef" :
6870 self .class_name = function .top_level_parent_name
6971
@@ -328,6 +330,7 @@ def inject_profiling_into_existing_test(
328330 tests_project_root : Path ,
329331 test_framework : str ,
330332 mode : TestingMode = TestingMode .BEHAVIOR ,
333+ is_async : bool = False ,
331334) -> tuple [bool , str | None ]:
332335 with test_path .open (encoding = "utf8" ) as f :
333336 test_code = f .read ()
@@ -342,7 +345,9 @@ def inject_profiling_into_existing_test(
342345 import_visitor .visit (tree )
343346 func = import_visitor .imported_as
344347
345- tree = InjectPerfOnly (func , test_module_path , test_framework , call_positions , mode = mode ).visit (tree )
348+ tree = InjectPerfOnly (func , test_module_path , test_framework , call_positions , mode = mode , is_async = is_async ).visit (
349+ tree
350+ )
346351 new_imports = [
347352 ast .Import (names = [ast .alias (name = "time" )]),
348353 ast .Import (names = [ast .alias (name = "gc" )]),
@@ -354,11 +359,11 @@ def inject_profiling_into_existing_test(
354359 )
355360 if test_framework == "unittest" :
356361 new_imports .append (ast .Import (names = [ast .alias (name = "timeout_decorator" )]))
357- tree .body = [* new_imports , create_wrapper_function (mode ), * tree .body ]
362+ tree .body = [* new_imports , create_wrapper_function (mode , is_async ), * tree .body ]
358363 return True , isort .code (ast .unparse (tree ), float_to_top = True )
359364
360365
361- def create_wrapper_function (mode : TestingMode = TestingMode .BEHAVIOR ) -> ast .FunctionDef :
366+ def create_wrapper_function (mode : TestingMode = TestingMode .BEHAVIOR , is_async : bool = False ) -> ast .FunctionDef :
362367 lineno = 1
363368 wrapper_body : list [ast .stmt ] = [
364369 ast .Assign (
@@ -536,7 +541,15 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
536541 ),
537542 ast .Assign (
538543 targets = [ast .Name (id = "return_value" , ctx = ast .Store ())],
539- value = ast .Call (
544+ value = ast .Await (
545+ value = ast .Call (
546+ func = ast .Name (id = "wrapped" , ctx = ast .Load ()),
547+ args = [ast .Starred (value = ast .Name (id = "args" , ctx = ast .Load ()), ctx = ast .Load ())],
548+ keywords = [ast .keyword (arg = None , value = ast .Name (id = "kwargs" , ctx = ast .Load ()))],
549+ )
550+ )
551+ if is_async
552+ else ast .Call (
540553 func = ast .Name (id = "wrapped" , ctx = ast .Load ()),
541554 args = [ast .Starred (value = ast .Name (id = "args" , ctx = ast .Load ()), ctx = ast .Load ())],
542555 keywords = [ast .keyword (arg = None , value = ast .Name (id = "kwargs" , ctx = ast .Load ()))],
@@ -703,7 +716,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
703716 ),
704717 ast .Return (value = ast .Name (id = "return_value" , ctx = ast .Load ()), lineno = lineno + 19 ),
705718 ]
706- return ast .FunctionDef (
719+ func_def = ast .FunctionDef (
707720 name = "codeflash_wrap" ,
708721 args = ast .arguments (
709722 args = [
@@ -729,3 +742,13 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
729742 decorator_list = [],
730743 returns = None ,
731744 )
745+ if is_async :
746+ return ast .AsyncFunctionDef (
747+ name = "codeflash_wrap" ,
748+ args = func_def .args ,
749+ body = func_def .body ,
750+ lineno = func_def .lineno ,
751+ decorator_list = func_def .decorator_list ,
752+ returns = func_def .returns ,
753+ )
754+ return func_def
0 commit comments