66import os
77import inspect
88import logging
9+ import asyncio
910from datetime import datetime
1011from azure .ai .evaluation ._common ._experimental import experimental
11- from typing import Any , Callable , Dict , List , Optional , Union , cast
12+ from typing import Any , Callable , Dict , List , Optional , Union , cast , Coroutine , TypeVar , Awaitable
1213from azure .ai .evaluation ._common .math import list_mean_nan_safe
1314from azure .ai .evaluation ._constants import CONTENT_SAFETY_DEFECT_RATE_THRESHOLD_DEFAULT
1415from azure .ai .evaluation ._evaluators import (
@@ -192,10 +193,17 @@ async def callback(
192193 context = latest_message .get ("context" , None )
193194 latest_context = None
194195 try :
196+ is_async = self ._is_async_function (target )
195197 if self ._check_target_returns_context (target ):
196- response , latest_context = target (query = application_input )
198+ if is_async :
199+ response , latest_context = await target (query = application_input )
200+ else :
201+ response , latest_context = target (query = application_input )
197202 else :
198- response = target (query = application_input )
203+ if is_async :
204+ response = await target (query = application_input )
205+ else :
206+ response = target (query = application_input )
199207 except Exception as e :
200208 response = f"Something went wrong { e !s} "
201209
@@ -465,7 +473,7 @@ def _get_evaluators(
465473 blame = ErrorBlame .USER_ERROR ,
466474 )
467475 return evaluators_dict
468-
476+
469477 @staticmethod
470478 def _check_target_returns_context (target : Callable ) -> bool :
471479 """
@@ -478,6 +486,15 @@ def _check_target_returns_context(target: Callable) -> bool:
478486 ret_type = sig .return_annotation
479487 if ret_type == inspect .Signature .empty :
480488 return False
489+
490+ # Check for Coroutine/Awaitable return types for async functions
491+ origin = getattr (ret_type , "__origin__" , None )
492+ if origin is not None and (origin is Coroutine or origin is Awaitable ):
493+ args = getattr (ret_type , "__args__" , None )
494+ if args and len (args ) > 0 :
495+ # For async functions, check the actual return type inside the Coroutine
496+ ret_type = args [- 1 ]
497+
481498 if ret_type is tuple :
482499 return True
483500 return False
@@ -494,13 +511,33 @@ def _check_target_returns_str(target: Callable) -> bool:
494511 ret_type = sig .return_annotation
495512 if ret_type == inspect .Signature .empty :
496513 return False
514+
515+ # Check for Coroutine/Awaitable return types for async functions
516+ origin = getattr (ret_type , "__origin__" , None )
517+ if origin is not None and (origin is Coroutine or origin is Awaitable ):
518+ args = getattr (ret_type , "__args__" , None )
519+ if args and len (args ) > 0 :
520+ # For async functions, check the actual return type inside the Coroutine
521+ ret_type = args [- 1 ]
522+
497523 if ret_type is str :
498524 return True
499525 return False
500526
501-
502527 @staticmethod
503- def _check_target_is_callback (target :Callable ) -> bool :
528+ def _is_async_function (target : Callable ) -> bool :
529+ """
530+ Checks if the target function is an async function.
531+
532+ :param target: The target function to check.
533+ :type target: Callable
534+ :return: True if the target function is async, False otherwise.
535+ :rtype: bool
536+ """
537+ return asyncio .iscoroutinefunction (target )
538+
539+ @staticmethod
540+ def _check_target_is_callback (target : Callable ) -> bool :
504541 sig = inspect .signature (target )
505542 param_names = list (sig .parameters .keys ())
506543 return 'messages' in param_names and 'stream' in param_names and 'session_state' in param_names and 'context' in param_names
@@ -630,7 +667,7 @@ def _calculate_defect_rate(self, evaluation_result_dict) -> EvaluationResult:
630667
631668 async def __call__ (
632669 self ,
633- target : Union [Callable , AzureOpenAIModelConfiguration , OpenAIModelConfiguration ],
670+ target : Union [Callable , Awaitable [ Any ], AzureOpenAIModelConfiguration , OpenAIModelConfiguration ],
634671 evaluators : List [_SafetyEvaluator ] = [],
635672 evaluation_name : Optional [str ] = None ,
636673 num_turns : int = 1 ,
@@ -644,12 +681,12 @@ async def __call__(
644681 jailbreak_data_path : Optional [Union [str , os .PathLike ]] = None ,
645682 output_path : Optional [Union [str , os .PathLike ]] = None ,
646683 data_paths : Optional [Union [Dict [str , str ], Dict [str , Union [str ,os .PathLike ]]]] = None
647- ) -> Union [Dict [str , EvaluationResult ], Dict [str , str ], Dict [str , Union [str ,os .PathLike ]]]:
684+ ) -> Union [Dict [str , EvaluationResult ], Dict [str , str ], Dict [str , Union [str ,os .PathLike ]]]:
648685 '''
649686 Evaluates the target function based on the provided parameters.
650687
651- :param target: The target function to call during the evaluation.
652- :type target: Callable
688+ :param target: The target function to call during the evaluation. This can be a synchronous or asynchronous function.
689+ :type target: Union[ Callable, Awaitable[Any], AzureOpenAIModelConfiguration, OpenAIModelConfiguration]
653690 :param evaluators: A list of SafetyEvaluator.
654691 :type evaluators: List[_SafetyEvaluator]
655692 :param evaluation_name: The display name name of the evaluation.
0 commit comments