@@ -48,6 +48,39 @@ def is_argument_name(name: str, arguments_node: ast.arguments) -> bool:
4848 )
4949
5050
51+ class AsyncIOGatherRemover (ast .NodeTransformer ):
52+ def _contains_asyncio_gather (self , node : ast .AST ) -> bool :
53+ """Check if a node contains asyncio.gather calls."""
54+ for child_node in ast .walk (node ):
55+ if (
56+ isinstance (child_node , ast .Call )
57+ and isinstance (child_node .func , ast .Attribute )
58+ and isinstance (child_node .func .value , ast .Name )
59+ and child_node .func .value .id == "asyncio"
60+ and child_node .func .attr == "gather"
61+ ):
62+ return True
63+
64+ if (
65+ isinstance (child_node , ast .Call )
66+ and isinstance (child_node .func , ast .Name )
67+ and child_node .func .id == "gather"
68+ ):
69+ return True
70+
71+ return False
72+
73+ def visit_FunctionDef (self , node : ast .FunctionDef ) -> ast .FunctionDef | None :
74+ if node .name .startswith ("test_" ) and self ._contains_asyncio_gather (node ):
75+ return None
76+ return self .generic_visit (node )
77+
78+ def visit_AsyncFunctionDef (self , node : ast .AsyncFunctionDef ) -> ast .AsyncFunctionDef | None :
79+ if node .name .startswith ("test_" ) and self ._contains_asyncio_gather (node ):
80+ return None
81+ return self .generic_visit (node )
82+
83+
5184class InjectPerfOnly (ast .NodeTransformer ):
5285 def __init__ (
5386 self ,
@@ -397,6 +430,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
397430 file_path = self .function .file_path ,
398431 starting_line = self .function .starting_line ,
399432 ending_line = self .function .ending_line ,
433+ is_async = self .function .is_async ,
400434 )
401435 else :
402436 self .imported_as = FunctionToOptimize (
@@ -405,6 +439,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
405439 file_path = self .function .file_path ,
406440 starting_line = self .function .starting_line ,
407441 ending_line = self .function .ending_line ,
442+ is_async = self .function .is_async ,
408443 )
409444
410445
@@ -415,7 +450,6 @@ def inject_profiling_into_existing_test(
415450 tests_project_root : Path ,
416451 test_framework : str ,
417452 mode : TestingMode = TestingMode .BEHAVIOR ,
418- is_async : bool = False ,
419453) -> tuple [bool , str | None ]:
420454 with test_path .open (encoding = "utf8" ) as f :
421455 test_code = f .read ()
@@ -430,6 +464,13 @@ def inject_profiling_into_existing_test(
430464 import_visitor .visit (tree )
431465 func = import_visitor .imported_as
432466
467+ is_async = function_to_optimize .is_async
468+ logger .debug (f"Using async status from discovery phase for { function_to_optimize .function_name } : { is_async } " )
469+
470+ if is_async :
471+ asyncio_gather_remover = AsyncIOGatherRemover ()
472+ tree = asyncio_gather_remover .visit (tree )
473+
433474 tree = InjectPerfOnly (func , test_module_path , test_framework , call_positions , mode = mode , is_async = is_async ).visit (
434475 tree
435476 )
@@ -444,11 +485,15 @@ def inject_profiling_into_existing_test(
444485 )
445486 if test_framework == "unittest" :
446487 new_imports .append (ast .Import (names = [ast .alias (name = "timeout_decorator" )]))
488+ if is_async :
489+ new_imports .append (ast .Import (names = [ast .alias (name = "inspect" )]))
447490 tree .body = [* new_imports , create_wrapper_function (mode , is_async ), * tree .body ]
448491 return True , isort .code (ast .unparse (tree ), float_to_top = True )
449492
450493
451- def create_wrapper_function (mode : TestingMode = TestingMode .BEHAVIOR , is_async : bool = False ) -> ast .FunctionDef :
494+ def create_wrapper_function (
495+ mode : TestingMode = TestingMode .BEHAVIOR , is_async : bool = False
496+ ) -> ast .FunctionDef | ast .AsyncFunctionDef :
452497 lineno = 1
453498 wrapper_body : list [ast .stmt ] = [
454499 ast .Assign (
@@ -624,22 +669,70 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR, is_async:
624669 ),
625670 lineno = lineno + 11 ,
626671 ),
627- ast .Assign (
628- targets = [ast .Name (id = "return_value" , ctx = ast .Store ())],
629- value = ast .Await (
630- value = ast .Call (
631- func = ast .Name (id = "wrapped" , ctx = ast .Load ()),
632- args = [ast .Starred (value = ast .Name (id = "args" , ctx = ast .Load ()), ctx = ast .Load ())],
633- keywords = [ast .keyword (arg = None , value = ast .Name (id = "kwargs" , ctx = ast .Load ()))],
634- )
635- )
672+ # For async wrappers
673+ # Call the wrapped function first, then check if result is awaitable before awaiting.
674+ # This handles mixed scenarios where async tests might call both sync and async functions.
675+ * (
676+ [
677+ ast .Assign (
678+ targets = [ast .Name (id = "ret" , ctx = ast .Store ())],
679+ value = ast .Call (
680+ func = ast .Name (id = "wrapped" , ctx = ast .Load ()),
681+ args = [ast .Starred (value = ast .Name (id = "args" , ctx = ast .Load ()), ctx = ast .Load ())],
682+ keywords = [ast .keyword (arg = None , value = ast .Name (id = "kwargs" , ctx = ast .Load ()))],
683+ ),
684+ lineno = lineno + 12 ,
685+ ),
686+ ast .If (
687+ test = ast .Call (
688+ func = ast .Attribute (
689+ value = ast .Name (id = "inspect" , ctx = ast .Load ()), attr = "isawaitable" , ctx = ast .Load ()
690+ ),
691+ args = [ast .Name (id = "ret" , ctx = ast .Load ())],
692+ keywords = [],
693+ ),
694+ body = [
695+ ast .Assign (
696+ targets = [ast .Name (id = "counter" , ctx = ast .Store ())],
697+ value = ast .Call (
698+ func = ast .Attribute (
699+ value = ast .Name (id = "time" , ctx = ast .Load ()),
700+ attr = "perf_counter_ns" ,
701+ ctx = ast .Load (),
702+ ),
703+ args = [],
704+ keywords = [],
705+ ),
706+ lineno = lineno + 14 ,
707+ ),
708+ ast .Assign (
709+ targets = [ast .Name (id = "return_value" , ctx = ast .Store ())],
710+ value = ast .Await (value = ast .Name (id = "ret" , ctx = ast .Load ())),
711+ lineno = lineno + 15 ,
712+ ),
713+ ],
714+ orelse = [
715+ ast .Assign (
716+ targets = [ast .Name (id = "return_value" , ctx = ast .Store ())],
717+ value = ast .Name (id = "ret" , ctx = ast .Load ()),
718+ lineno = lineno + 16 ,
719+ )
720+ ],
721+ lineno = lineno + 13 ,
722+ ),
723+ ]
636724 if is_async
637- else ast .Call (
638- func = ast .Name (id = "wrapped" , ctx = ast .Load ()),
639- args = [ast .Starred (value = ast .Name (id = "args" , ctx = ast .Load ()), ctx = ast .Load ())],
640- keywords = [ast .keyword (arg = None , value = ast .Name (id = "kwargs" , ctx = ast .Load ()))],
641- ),
642- lineno = lineno + 12 ,
725+ else [
726+ ast .Assign (
727+ targets = [ast .Name (id = "return_value" , ctx = ast .Store ())],
728+ value = ast .Call (
729+ func = ast .Name (id = "wrapped" , ctx = ast .Load ()),
730+ args = [ast .Starred (value = ast .Name (id = "args" , ctx = ast .Load ()), ctx = ast .Load ())],
731+ keywords = [ast .keyword (arg = None , value = ast .Name (id = "kwargs" , ctx = ast .Load ()))],
732+ ),
733+ lineno = lineno + 12 ,
734+ )
735+ ]
643736 ),
644737 ast .Assign (
645738 targets = [ast .Name (id = "codeflash_duration" , ctx = ast .Store ())],
0 commit comments