1212from typing import Dict , List , Optional , Union , cast
1313from urllib .parse import urlparse
1414from string import Template
15+ from azure .ai .evaluation ._common .onedp ._client import AIProjectClient
16+ from azure .core .exceptions import HttpResponseError
1517
1618import jwt
1719
1820from azure .ai .evaluation ._legacy ._adapters ._errors import MissingRequiredPackage
1921from azure .ai .evaluation ._exceptions import ErrorBlame , ErrorCategory , ErrorTarget , EvaluationException
2022from azure .ai .evaluation ._http_utils import AsyncHttpPipeline , get_async_http_client
2123from azure .ai .evaluation ._model_configurations import AzureAIProject
24+ from azure .ai .evaluation ._common .utils import is_onedp_project
2225from azure .core .credentials import TokenCredential
2326from azure .core .exceptions import HttpResponseError
2427from azure .core .pipeline .policies import AsyncRetryPolicy
# string.Template objects with ``$query`` / ``$response`` placeholders.
# Only a "DEFAULT" entry is defined here.
USER_TEXT_TEMPLATE_DICT: Dict[str, Template] = {"DEFAULT": Template("<Human>{$query}</><System>{$response}</>")}

# Token scopes handed to ``credential.get_token`` by ``fetch_or_reuse_token``.
# ML_WORKSPACE is that function's default; COG_SRV_WORKSPACE is passed explicitly
# on the OneDP project path.
ML_WORKSPACE = "https://management.azure.com/.default"
COG_SRV_WORKSPACE = "https://cognitiveservices.azure.com/.default"

INFERENCE_OF_SENSITIVE_ATTRIBUTES = "inference_sensitive_attributes"
4651
@@ -99,11 +104,7 @@ def get_common_headers(token: str, evaluator_name: Optional[str] = None) -> Dict
99104 user_agent = f"{ USER_AGENT } (type=evaluator; subtype={ evaluator_name } )" if evaluator_name else USER_AGENT
100105 return {
101106 "Authorization" : f"Bearer { token } " ,
102- "Content-Type" : "application/json" ,
103107 "User-Agent" : user_agent ,
104- # Handle "RuntimeError: Event loop is closed" from httpx AsyncClient
105- # https://github.com/encode/httpx/discussions/2959
106- "Connection" : "close" ,
107108 }
108109
109110
@@ -112,7 +113,31 @@ def get_async_http_client_with_timeout() -> AsyncHttpPipeline:
112113 retry_policy = AsyncRetryPolicy (timeout = CommonConstants .DEFAULT_HTTP_TIMEOUT )
113114 )
114115
async def ensure_service_availability_onedp(
    client: AIProjectClient, token: str, capability: Optional[str] = None
) -> None:
    """Check that the Responsible AI service reports the required capability, if one is requested.

    The capability list is obtained from ``client.evaluations.check_annotation``. The
    request is always issued; the membership check is only performed when
    ``capability`` is truthy.

    :param client: The AI project client.
    :type client: AIProjectClient
    :param token: The Azure authentication token.
    :type token: str
    :param capability: The capability to check. Default is None, which skips the capability check.
    :type capability: Optional[str]
    :raises EvaluationException: If ``capability`` is given and is not among the capabilities
        returned by the service.
    """
    headers = get_common_headers(token)
    # NOTE(review): this client call is not awaited, so it runs synchronously inside
    # this coroutine — confirm that is intended for the client in use.
    capabilities = client.evaluations.check_annotation(headers=headers)

    # Nothing to verify when no capability was requested or the service lists it.
    if not capability or capability in capabilities:
        return

    msg = f"The needed capability '{capability}' is not supported by the RAI service in this region."
    raise EvaluationException(
        message=msg,
        internal_message=msg,
        target=ErrorTarget.RAI_CLIENT,
        category=ErrorCategory.SERVICE_UNAVAILABLE,
        blame=ErrorBlame.USER_ERROR,
        tsg_link="https://aka.ms/azsdk/python/evaluation/safetyevaluator/troubleshoot",
    )
140+
116141async def ensure_service_availability (rai_svc_url : str , token : str , capability : Optional [str ] = None ) -> None :
117142 """Check if the Responsible AI service is available in the region and has the required capability, if relevant.
118143
@@ -231,6 +256,40 @@ async def submit_request(
231256 return operation_id
232257
233258
async def submit_request_onedp(
    client: AIProjectClient,
    data: dict,
    metric: str,
    token: str,
    annotation_task: str,
    evaluator_name: str,
) -> str:
    """Submit an annotation request through the project client and return its operation ID.

    :param client: The AI project client.
    :type client: AIProjectClient
    :param data: The data to evaluate.
    :type data: dict
    :param metric: The evaluation metric to use.
    :type metric: str
    :param token: The Azure authentication token.
    :type token: str
    :param annotation_task: The annotation task to use.
    :type annotation_task: str
    :param evaluator_name: The evaluator name.
    :type evaluator_name: str
    :return: The operation ID.
    :rtype: str
    """
    # Build the request body from the formatted user text.
    user_text = get_formatted_template(data, annotation_task)
    request_body = generate_payload(user_text, metric, annotation_task=annotation_task)
    request_headers = get_common_headers(token, evaluator_name)

    raw_response = client.evaluations.submit_annotation(request_body, headers=request_headers)

    # The operation ID is the final path segment of the "location" value in the JSON response.
    location = json.loads(raw_response)["location"]
    return location.split("/")[-1]
291+
292+
234293async def fetch_result (operation_id : str , rai_svc_url : str , credential : TokenCredential , token : str ) -> Dict :
235294 """Fetch the annotation result from Responsible AI service
236295
@@ -267,6 +326,34 @@ async def fetch_result(operation_id: str, rai_svc_url: str, credential: TokenCre
267326 sleep_time = RAIService .SLEEP_TIME ** request_count
268327 await asyncio .sleep (sleep_time )
269328
async def fetch_result_onedp(client: AIProjectClient, operation_id: str, token: str) -> Dict:
    """Poll the project client for an annotation result until it is returned or the timeout passes.

    Each ``HttpResponseError`` counts as one failed attempt. After a failure the
    coroutine sleeps ``RAIService.SLEEP_TIME ** request_count`` seconds before
    retrying, unless more than ``RAIService.TIMEOUT`` seconds have elapsed.

    :param client: The AI project client.
    :type client: AIProjectClient
    :param operation_id: The operation ID.
    :type operation_id: str
    :param token: The Azure authentication token.
    :type token: str
    :return: The annotation result.
    :rtype: Dict
    :raises TimeoutError: If a request fails after more than ``RAIService.TIMEOUT`` seconds.
    """
    started_at = time.time()
    request_count = 0

    while True:
        request_headers = get_common_headers(token)
        try:
            return client.evaluations.operation_results(operation_id, headers=request_headers)
        except HttpResponseError:
            request_count += 1
            time_elapsed = time.time() - started_at
            if time_elapsed > RAIService.TIMEOUT:
                # Raised inside the handler so the HTTP error stays attached as context.
                raise TimeoutError(f"Fetching annotation result { request_count } times out after { time_elapsed :.2f} seconds")

            # Back off exponentially in the number of failed attempts.
            await asyncio.sleep(RAIService.SLEEP_TIME ** request_count)
356+
270357def parse_response ( # pylint: disable=too-many-branches,too-many-statements
271358 batch_response : List [Dict ], metric_name : str , metric_display_name : Optional [str ] = None
272359) -> Dict [str , Union [str , float ]]:
@@ -500,7 +587,7 @@ async def get_rai_svc_url(project_scope: AzureAIProject, token: str) -> str:
500587 return rai_url
501588
502589
503- async def fetch_or_reuse_token (credential : TokenCredential , token : Optional [str ] = None ) -> str :
590+ async def fetch_or_reuse_token (credential : TokenCredential , token : Optional [str ] = None , workspace : Optional [ str ] = ML_WORKSPACE ) -> str :
504591 """Get token. Fetch a new token if the current token is near expiry
505592
506593 :param credential: The Azure authentication credential.
@@ -524,13 +611,13 @@ async def fetch_or_reuse_token(credential: TokenCredential, token: Optional[str]
524611 if (exp_time - current_time ) >= 300 :
525612 return token
526613
527- return credential .get_token ("https://management.azure.com/.default" ).token
614+ return credential .get_token (workspace ).token
528615
529616
530617async def evaluate_with_rai_service (
531618 data : dict ,
532619 metric_name : str ,
533- project_scope : AzureAIProject ,
620+ project_scope : Union [ str , AzureAIProject ] ,
534621 credential : TokenCredential ,
535622 annotation_task : str = Tasks .CONTENT_HARM ,
536623 metric_display_name = None ,
@@ -556,18 +643,26 @@ async def evaluate_with_rai_service(
556643 :rtype: Dict[str, Union[str, float]]
557644 """
558645
559- # Get RAI service URL from discovery service and check service availability
560- token = await fetch_or_reuse_token (credential )
561- rai_svc_url = await get_rai_svc_url (project_scope , token )
562- await ensure_service_availability (rai_svc_url , token , annotation_task )
563-
564- # Submit annotation request and fetch result
565- operation_id = await submit_request (data , metric_name , rai_svc_url , token , annotation_task , evaluator_name )
566- annotation_response = cast (List [Dict ], await fetch_result (operation_id , rai_svc_url , credential , token ))
567- result = parse_response (annotation_response , metric_name , metric_display_name )
646+ if is_onedp_project (project_scope ):
647+ client = AIProjectClient (endpoint = project_scope , credential = credential )
648+ token = await fetch_or_reuse_token (credential = credential , workspace = COG_SRV_WORKSPACE )
649+ await ensure_service_availability_onedp (client , token , annotation_task )
650+ operation_id = await submit_request_onedp (client , data , metric_name , token , annotation_task , evaluator_name )
651+ annotation_response = cast (List [Dict ], await fetch_result_onedp (client , operation_id , token ))
652+ result = parse_response (annotation_response , metric_name , metric_display_name )
653+ return result
654+ else :
655+ # Get RAI service URL from discovery service and check service availability
656+ token = await fetch_or_reuse_token (credential )
657+ rai_svc_url = await get_rai_svc_url (project_scope , token )
658+ await ensure_service_availability (rai_svc_url , token , annotation_task )
568659
569- return result
660+ # Submit annotation request and fetch result
661+ operation_id = await submit_request (data , metric_name , rai_svc_url , token , annotation_task , evaluator_name )
662+ annotation_response = cast (List [Dict ], await fetch_result (operation_id , rai_svc_url , credential , token ))
663+ result = parse_response (annotation_response , metric_name , metric_display_name )
570664
665+ return result
571666
572667def generate_payload_multimodal (content_type : str , messages , metric : str ) -> Dict :
573668 """Generate the payload for the annotation request
@@ -600,7 +695,6 @@ def generate_payload_multimodal(content_type: str, messages, metric: str) -> Dic
600695 "AnnotationTask" : task ,
601696 }
602697
603-
604698async def submit_multimodal_request (messages , metric : str , rai_svc_url : str , token : str ) -> str :
605699 """Submit request to Responsible AI service for evaluation and return operation ID
606700 :param messages: The normalized list of messages to be entered as the "Contents" in the payload.
@@ -646,9 +740,37 @@ async def submit_multimodal_request(messages, metric: str, rai_svc_url: str, tok
646740 operation_id = result ["location" ].split ("/" )[- 1 ]
647741 return operation_id
648742
async def submit_multimodal_request_onedp(client: AIProjectClient, messages, metric: str, token: str) -> str:
    """Submit a multimodal annotation request through the project client and return its operation ID.

    :param client: The AI project client.
    :type client: AIProjectClient
    :param messages: The list of messages; either dicts with a "role" key or
        ``azure.ai.inference`` ``ChatRequestMessage`` objects, which are converted to dicts.
    :param metric: The evaluation metric to use.
    :type metric: str
    :param token: The Azure authentication token.
    :type token: str
    :return: The operation ID.
    :rtype: str
    :raises MissingRequiredPackage: If the messages are not dicts and ``azure-ai-inference``
        is not installed.
    """
    # Strongly typed inference-SDK messages are converted to plain dicts first.
    if len(messages) > 0 and not isinstance(messages[0], dict):
        try:
            from azure.ai.inference.models import ChatRequestMessage
        except ImportError as ex:
            error_message = (
                "Please install 'azure-ai-inference' package to use SystemMessage, UserMessage, AssistantMessage"
            )
            raise MissingRequiredPackage(message=error_message) from ex
        # The outer guard already established that messages is non-empty.
        if isinstance(messages[0], ChatRequestMessage):
            messages = [message.as_dict() for message in messages]

    # Partition in one pass: everything except system messages goes in the payload,
    # and assistant messages alone determine the content type.
    non_system_messages = []
    assistant_messages = []
    for message in messages:
        role = message["role"]
        if role != "system":
            non_system_messages.append(message)
        if role == "assistant":
            assistant_messages.append(message)

    # Prepare and send the request.
    content_type = retrieve_content_type(assistant_messages, metric)
    request_body = generate_payload_multimodal(content_type, non_system_messages, metric)
    request_headers = get_common_headers(token)

    raw_response = client.evaluations.submit_annotation(request_body, headers=request_headers)

    # The operation ID is the final path segment of the "location" value in the JSON response.
    location = json.loads(raw_response)["location"]
    return location.split("/")[-1]
649771
650772async def evaluate_with_rai_service_multimodal (
651- messages , metric_name : str , project_scope : AzureAIProject , credential : TokenCredential
773+ messages , metric_name : str , project_scope : Union [ str , AzureAIProject ] , credential : TokenCredential
652774):
653775 """ "Evaluate the content safety of the response using Responsible AI service
654776 :param messages: The normalized list of messages.
@@ -664,12 +786,20 @@ async def evaluate_with_rai_service_multimodal(
664786 :rtype: List[List[Dict]]
665787 """
666788
667- # Get RAI service URL from discovery service and check service availability
668- token = await fetch_or_reuse_token (credential )
669- rai_svc_url = await get_rai_svc_url (project_scope , token )
670- await ensure_service_availability (rai_svc_url , token , Tasks .CONTENT_HARM )
671- # Submit annotation request and fetch result
672- operation_id = await submit_multimodal_request (messages , metric_name , rai_svc_url , token )
673- annotation_response = cast (List [Dict ], await fetch_result (operation_id , rai_svc_url , credential , token ))
674- result = parse_response (annotation_response , metric_name )
675- return result
789+ if is_onedp_project (project_scope ):
790+ client = AIProjectClient (endpoint = project_scope , credential = credential )
791+ token = await fetch_or_reuse_token (credential = credential , workspace = COG_SRV_WORKSPACE )
792+ await ensure_service_availability_onedp (client , token , Tasks .CONTENT_HARM )
793+ operation_id = await submit_multimodal_request_onedp (client , messages , metric_name , token )
794+ annotation_response = cast (List [Dict ], await fetch_result_onedp (client , operation_id , token ))
795+ result = parse_response (annotation_response , metric_name )
796+ return result
797+ else :
798+ token = await fetch_or_reuse_token (credential )
799+ rai_svc_url = await get_rai_svc_url (project_scope , token )
800+ await ensure_service_availability (rai_svc_url , token , Tasks .CONTENT_HARM )
801+ # Submit annotation request and fetch result
802+ operation_id = await submit_multimodal_request (messages , metric_name , rai_svc_url , token )
803+ annotation_response = cast (List [Dict ], await fetch_result (operation_id , rai_svc_url , credential , token ))
804+ result = parse_response (annotation_response , metric_name )
805+ return result
0 commit comments