@@ -291,6 +291,139 @@ def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None =
291291 return node
292292
293293
294+ class AsyncCallInstrumenter (ast .NodeTransformer ):
295+ def __init__ (
296+ self ,
297+ function : FunctionToOptimize ,
298+ module_path : str ,
299+ test_framework : str ,
300+ call_positions : list [CodePosition ],
301+ mode : TestingMode = TestingMode .BEHAVIOR ,
302+ ) -> None :
303+ self .mode = mode
304+ self .function_object = function
305+ self .class_name = None
306+ self .only_function_name = function .function_name
307+ self .module_path = module_path
308+ self .test_framework = test_framework
309+ self .call_positions = call_positions
310+ self .did_instrument = False
311+ # Track function call count per test function
312+ self .async_call_counter : dict [str , int ] = {}
313+ if len (function .parents ) == 1 and function .parents [0 ].type == "ClassDef" :
314+ self .class_name = function .top_level_parent_name
315+
316+ def visit_ClassDef (self , node : ast .ClassDef ) -> ast .ClassDef :
317+ # Add timeout decorator for unittest test classes if needed
318+ if self .test_framework == "unittest" :
319+ for item in node .body :
320+ if (
321+ isinstance (item , ast .FunctionDef )
322+ and item .name .startswith ("test_" )
323+ and not any (
324+ isinstance (d , ast .Call )
325+ and isinstance (d .func , ast .Name )
326+ and d .func .id == "timeout_decorator.timeout"
327+ for d in item .decorator_list
328+ )
329+ ):
330+ timeout_decorator = ast .Call (
331+ func = ast .Name (id = "timeout_decorator.timeout" , ctx = ast .Load ()),
332+ args = [ast .Constant (value = 15 )],
333+ keywords = [],
334+ )
335+ item .decorator_list .append (timeout_decorator )
336+ return self .generic_visit (node )
337+
338+ def visit_AsyncFunctionDef (self , node : ast .AsyncFunctionDef ) -> ast .AsyncFunctionDef :
339+ if not node .name .startswith ("test_" ):
340+ return node
341+
342+ return self ._process_test_function (node )
343+
344+ def visit_FunctionDef (self , node : ast .FunctionDef ) -> ast .FunctionDef :
345+ # Only process test functions
346+ if not node .name .startswith ("test_" ):
347+ return node
348+
349+ return self ._process_test_function (node )
350+
351+ def _process_test_function (
352+ self , node : ast .AsyncFunctionDef | ast .FunctionDef
353+ ) -> ast .AsyncFunctionDef | ast .FunctionDef :
354+ if self .test_framework == "unittest" and not any (
355+ isinstance (d , ast .Call ) and isinstance (d .func , ast .Name ) and d .func .id == "timeout_decorator.timeout"
356+ for d in node .decorator_list
357+ ):
358+ timeout_decorator = ast .Call (
359+ func = ast .Name (id = "timeout_decorator.timeout" , ctx = ast .Load ()),
360+ args = [ast .Constant (value = 15 )],
361+ keywords = [],
362+ )
363+ node .decorator_list .append (timeout_decorator )
364+
365+ # Initialize counter for this test function
366+ if node .name not in self .async_call_counter :
367+ self .async_call_counter [node .name ] = 0
368+
369+ new_body = []
370+
371+ for i , stmt in enumerate (node .body ):
372+ transformed_stmt , added_env_assignment = self ._instrument_statement (stmt , node .name )
373+
374+ if added_env_assignment :
375+ current_call_index = self .async_call_counter [node .name ]
376+ self .async_call_counter [node .name ] += 1
377+
378+ env_assignment = ast .Assign (
379+ targets = [
380+ ast .Subscript (
381+ value = ast .Attribute (
382+ value = ast .Name (id = "os" , ctx = ast .Load ()), attr = "environ" , ctx = ast .Load ()
383+ ),
384+ slice = ast .Constant (value = "CODEFLASH_CURRENT_LINE_ID" ),
385+ ctx = ast .Store (),
386+ )
387+ ],
388+ value = ast .Constant (value = f"{ current_call_index } " ),
389+ lineno = stmt .lineno if hasattr (stmt , "lineno" ) else 1 ,
390+ )
391+ new_body .append (env_assignment )
392+ self .did_instrument = True
393+
394+ new_body .append (transformed_stmt )
395+
396+ node .body = new_body
397+ return node
398+
399+ def _instrument_statement (self , stmt : ast .stmt , node_name : str ) -> tuple [ast .stmt , bool ]:
400+ for node in ast .walk (stmt ):
401+ if (
402+ isinstance (node , ast .Await )
403+ and isinstance (node .value , ast .Call )
404+ and self ._is_target_call (node .value )
405+ and self ._call_in_positions (node .value )
406+ ):
407+ # Check if this call is in one of our target positions
408+ return stmt , True # Return original statement but signal we added env var
409+
410+ return stmt , False
411+
412+ def _is_target_call (self , call_node : ast .Call ) -> bool :
413+ """Check if this call node is calling our target async function."""
414+ if isinstance (call_node .func , ast .Name ):
415+ return call_node .func .id == self .function_object .function_name
416+ if isinstance (call_node .func , ast .Attribute ):
417+ return call_node .func .attr == self .function_object .function_name
418+ return False
419+
420+ def _call_in_positions (self , call_node : ast .Call ) -> bool :
421+ if not hasattr (call_node , "lineno" ) or not hasattr (call_node , "col_offset" ):
422+ return False
423+
424+ return node_in_call_position (call_node , self .call_positions )
425+
426+
294427class FunctionImportedAsVisitor (ast .NodeVisitor ):
295428 """Checks if a function has been imported as an alias. We only care about the alias then.
296429
@@ -352,6 +485,44 @@ def instrument_source_module_with_async_decorators(
352485 return False , None
353486
354487
488+ def inject_async_profiling_into_existing_test (
489+ test_path : Path ,
490+ call_positions : list [CodePosition ],
491+ function_to_optimize : FunctionToOptimize ,
492+ tests_project_root : Path ,
493+ test_framework : str ,
494+ mode : TestingMode = TestingMode .BEHAVIOR ,
495+ ) -> tuple [bool , str | None ]:
496+ """Inject profiling for async function calls by setting environment variables before each call."""
497+ with test_path .open (encoding = "utf8" ) as f :
498+ test_code = f .read ()
499+
500+ try :
501+ tree = ast .parse (test_code )
502+ except SyntaxError :
503+ logger .exception (f"Syntax error in code in file - { test_path } " )
504+ return False , None
505+
506+ test_module_path = module_name_from_file_path (test_path , tests_project_root )
507+ import_visitor = FunctionImportedAsVisitor (function_to_optimize )
508+ import_visitor .visit (tree )
509+ func = import_visitor .imported_as
510+
511+ async_instrumenter = AsyncCallInstrumenter (func , test_module_path , test_framework , call_positions , mode = mode )
512+ tree = async_instrumenter .visit (tree )
513+
514+ if not async_instrumenter .did_instrument :
515+ return False , None
516+
517+ # Add necessary imports
518+ new_imports = [ast .Import (names = [ast .alias (name = "os" )])]
519+ if test_framework == "unittest" :
520+ new_imports .append (ast .Import (names = [ast .alias (name = "timeout_decorator" )]))
521+
522+ tree .body = [* new_imports , * tree .body ]
523+ return True , isort .code (ast .unparse (tree ), float_to_top = True )
524+
525+
355526def inject_profiling_into_existing_test (
356527 test_path : Path ,
357528 call_positions : list [CodePosition ],
@@ -361,7 +532,9 @@ def inject_profiling_into_existing_test(
361532 mode : TestingMode = TestingMode .BEHAVIOR ,
362533) -> tuple [bool , str | None ]:
363534 if function_to_optimize .is_async :
364- return False , None
535+ return inject_async_profiling_into_existing_test (
536+ test_path , call_positions , function_to_optimize , tests_project_root , test_framework , mode
537+ )
365538
366539 with test_path .open (encoding = "utf8" ) as f :
367540 test_code = f .read ()
0 commit comments