Skip to content

Commit d1b40d0

Browse files
committed
check.
1 parent a7a06ea commit d1b40d0

File tree

3 files changed

+367
-3
lines changed

3 files changed

+367
-3
lines changed

codeflash/code_utils/codeflash_wrap_decorator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def codeflash_behavior_async(func: F) -> F:
4848
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
4949
loop = asyncio.get_running_loop()
5050
function_name = func.__name__
51-
line_id = f"{func.__name__}_{func.__code__.co_firstlineno}"
51+
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
5252
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
5353
test_module_name, test_class_name, test_name = extract_test_context_from_env()
5454

@@ -124,7 +124,7 @@ def codeflash_performance_async(func: F) -> F:
124124
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
125125
loop = asyncio.get_running_loop()
126126
function_name = func.__name__
127-
line_id = f"{func.__name__}_{func.__code__.co_firstlineno}"
127+
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
128128
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
129129

130130
test_module_name, test_class_name, test_name = extract_test_context_from_env()

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 174 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,139 @@ def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None =
291291
return node
292292

293293

294+
class AsyncCallInstrumenter(ast.NodeTransformer):
295+
def __init__(
296+
self,
297+
function: FunctionToOptimize,
298+
module_path: str,
299+
test_framework: str,
300+
call_positions: list[CodePosition],
301+
mode: TestingMode = TestingMode.BEHAVIOR,
302+
) -> None:
303+
self.mode = mode
304+
self.function_object = function
305+
self.class_name = None
306+
self.only_function_name = function.function_name
307+
self.module_path = module_path
308+
self.test_framework = test_framework
309+
self.call_positions = call_positions
310+
self.did_instrument = False
311+
# Track function call count per test function
312+
self.async_call_counter: dict[str, int] = {}
313+
if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
314+
self.class_name = function.top_level_parent_name
315+
316+
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
317+
# Add timeout decorator for unittest test classes if needed
318+
if self.test_framework == "unittest":
319+
for item in node.body:
320+
if (
321+
isinstance(item, ast.FunctionDef)
322+
and item.name.startswith("test_")
323+
and not any(
324+
isinstance(d, ast.Call)
325+
and isinstance(d.func, ast.Name)
326+
and d.func.id == "timeout_decorator.timeout"
327+
for d in item.decorator_list
328+
)
329+
):
330+
timeout_decorator = ast.Call(
331+
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
332+
args=[ast.Constant(value=15)],
333+
keywords=[],
334+
)
335+
item.decorator_list.append(timeout_decorator)
336+
return self.generic_visit(node)
337+
338+
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
339+
if not node.name.startswith("test_"):
340+
return node
341+
342+
return self._process_test_function(node)
343+
344+
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
345+
# Only process test functions
346+
if not node.name.startswith("test_"):
347+
return node
348+
349+
return self._process_test_function(node)
350+
351+
def _process_test_function(
352+
self, node: ast.AsyncFunctionDef | ast.FunctionDef
353+
) -> ast.AsyncFunctionDef | ast.FunctionDef:
354+
if self.test_framework == "unittest" and not any(
355+
isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "timeout_decorator.timeout"
356+
for d in node.decorator_list
357+
):
358+
timeout_decorator = ast.Call(
359+
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
360+
args=[ast.Constant(value=15)],
361+
keywords=[],
362+
)
363+
node.decorator_list.append(timeout_decorator)
364+
365+
# Initialize counter for this test function
366+
if node.name not in self.async_call_counter:
367+
self.async_call_counter[node.name] = 0
368+
369+
new_body = []
370+
371+
for i, stmt in enumerate(node.body):
372+
transformed_stmt, added_env_assignment = self._instrument_statement(stmt, node.name)
373+
374+
if added_env_assignment:
375+
current_call_index = self.async_call_counter[node.name]
376+
self.async_call_counter[node.name] += 1
377+
378+
env_assignment = ast.Assign(
379+
targets=[
380+
ast.Subscript(
381+
value=ast.Attribute(
382+
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
383+
),
384+
slice=ast.Constant(value="CODEFLASH_CURRENT_LINE_ID"),
385+
ctx=ast.Store(),
386+
)
387+
],
388+
value=ast.Constant(value=f"{current_call_index}"),
389+
lineno=stmt.lineno if hasattr(stmt, "lineno") else 1,
390+
)
391+
new_body.append(env_assignment)
392+
self.did_instrument = True
393+
394+
new_body.append(transformed_stmt)
395+
396+
node.body = new_body
397+
return node
398+
399+
def _instrument_statement(self, stmt: ast.stmt, node_name: str) -> tuple[ast.stmt, bool]:
400+
for node in ast.walk(stmt):
401+
if (
402+
isinstance(node, ast.Await)
403+
and isinstance(node.value, ast.Call)
404+
and self._is_target_call(node.value)
405+
and self._call_in_positions(node.value)
406+
):
407+
# Check if this call is in one of our target positions
408+
return stmt, True # Return original statement but signal we added env var
409+
410+
return stmt, False
411+
412+
def _is_target_call(self, call_node: ast.Call) -> bool:
413+
"""Check if this call node is calling our target async function."""
414+
if isinstance(call_node.func, ast.Name):
415+
return call_node.func.id == self.function_object.function_name
416+
if isinstance(call_node.func, ast.Attribute):
417+
return call_node.func.attr == self.function_object.function_name
418+
return False
419+
420+
def _call_in_positions(self, call_node: ast.Call) -> bool:
421+
if not hasattr(call_node, "lineno") or not hasattr(call_node, "col_offset"):
422+
return False
423+
424+
return node_in_call_position(call_node, self.call_positions)
425+
426+
294427
class FunctionImportedAsVisitor(ast.NodeVisitor):
295428
"""Checks if a function has been imported as an alias. We only care about the alias then.
296429
@@ -352,6 +485,44 @@ def instrument_source_module_with_async_decorators(
352485
return False, None
353486

354487

488+
def inject_async_profiling_into_existing_test(
489+
test_path: Path,
490+
call_positions: list[CodePosition],
491+
function_to_optimize: FunctionToOptimize,
492+
tests_project_root: Path,
493+
test_framework: str,
494+
mode: TestingMode = TestingMode.BEHAVIOR,
495+
) -> tuple[bool, str | None]:
496+
"""Inject profiling for async function calls by setting environment variables before each call."""
497+
with test_path.open(encoding="utf8") as f:
498+
test_code = f.read()
499+
500+
try:
501+
tree = ast.parse(test_code)
502+
except SyntaxError:
503+
logger.exception(f"Syntax error in code in file - {test_path}")
504+
return False, None
505+
506+
test_module_path = module_name_from_file_path(test_path, tests_project_root)
507+
import_visitor = FunctionImportedAsVisitor(function_to_optimize)
508+
import_visitor.visit(tree)
509+
func = import_visitor.imported_as
510+
511+
async_instrumenter = AsyncCallInstrumenter(func, test_module_path, test_framework, call_positions, mode=mode)
512+
tree = async_instrumenter.visit(tree)
513+
514+
if not async_instrumenter.did_instrument:
515+
return False, None
516+
517+
# Add necessary imports
518+
new_imports = [ast.Import(names=[ast.alias(name="os")])]
519+
if test_framework == "unittest":
520+
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
521+
522+
tree.body = [*new_imports, *tree.body]
523+
return True, isort.code(ast.unparse(tree), float_to_top=True)
524+
525+
355526
def inject_profiling_into_existing_test(
356527
test_path: Path,
357528
call_positions: list[CodePosition],
@@ -361,7 +532,9 @@ def inject_profiling_into_existing_test(
361532
mode: TestingMode = TestingMode.BEHAVIOR,
362533
) -> tuple[bool, str | None]:
363534
if function_to_optimize.is_async:
364-
return False, None
535+
return inject_async_profiling_into_existing_test(
536+
test_path, call_positions, function_to_optimize, tests_project_root, test_framework, mode
537+
)
365538

366539
with test_path.open(encoding="utf8") as f:
367540
test_code = f.read()

0 commit comments

Comments
 (0)