@@ -7,8 +7,9 @@
 import re
 import time
 from ast import literal_eval
-from typing import Any, Dict, List, Optional, Union, cast
+from typing import Dict, List, Optional, Union, cast
 from urllib.parse import urlparse
+from string import Template
 
 import jwt
 
@@ -23,7 +24,6 @@
     EvaluationMetrics,
     RAIService,
     Tasks,
-    _InternalAnnotationTasks,
     _InternalEvaluationMetrics,
 )
 from .utils import get_harm_severity_level
@@ -34,6 +34,11 @@
     version = "unknown"
 USER_AGENT = "{}/{}".format("azure-ai-evaluation", version)
 
+USER_TEXT_TEMPLATE_DICT: Dict[str, Template] = {
+    "DEFAULT": Template("<Human>{$query}</><System>{$response}</>"),
+    Tasks.GROUNDEDNESS: Template('{"question": "$query", "answer": "$response", "context": "$context"}'),
+}
+
 
 def get_common_headers(token: str) -> Dict:
     """Get common headers for the HTTP request
@@ -99,27 +104,26 @@ async def ensure_service_availability(rai_svc_url: str, token: str, capability:
     )
 
 
-def generate_payload(normalized_user_text: str, metric: str) -> Dict:
+def generate_payload(normalized_user_text: str, metric: str, annotation_task: str) -> Dict:
     """Generate the payload for the annotation request
 
     :param normalized_user_text: The normalized user text to be entered as the "UserTextList" in the payload.
     :type normalized_user_text: str
     :param metric: The evaluation metric to use. This determines the task type, and whether a "MetricList" is needed
         in the payload.
     :type metric: str
+    :param annotation_task: The annotation task to be passed to the service.
+    :type annotation_task: str
     :return: The payload for the annotation request.
     :rtype: Dict
     """
     include_metric = True
-    task = Tasks.CONTENT_HARM
+    task = annotation_task
     if metric == EvaluationMetrics.PROTECTED_MATERIAL:
-        task = Tasks.PROTECTED_MATERIAL
         include_metric = False
     elif metric == _InternalEvaluationMetrics.ECI:
-        task = _InternalAnnotationTasks.ECI
         include_metric = False
     elif metric == EvaluationMetrics.XPIA:
-        task = Tasks.XPIA
         include_metric = False
     return (
         {
@@ -135,25 +139,25 @@ def generate_payload(normalized_user_text: str, metric: str) -> Dict:
     )
 
 
-async def submit_request(query: str, response: str, metric: str, rai_svc_url: str, token: str) -> str:
+async def submit_request(data: dict, metric: str, rai_svc_url: str, token: str, annotation_task: str) -> str:
     """Submit request to Responsible AI service for evaluation and return operation ID
 
-    :param query: The query to evaluate.
-    :type query: str
-    :param response: The response to evaluate.
-    :type response: str
+    :param data: The data to evaluate.
+    :type data: dict
     :param metric: The evaluation metric to use.
     :type metric: str
     :param rai_svc_url: The Responsible AI service URL.
     :type rai_svc_url: str
     :param token: The Azure authentication token.
     :type token: str
+    :param annotation_task: The annotation task to use.
+    :type annotation_task: str
     :return: The operation ID.
     :rtype: str
     """
-    user_text = f"<Human>{query}</><System>{response}</>"
+    user_text = USER_TEXT_TEMPLATE_DICT.get(annotation_task, USER_TEXT_TEMPLATE_DICT["DEFAULT"]).substitute(**data)
     normalized_user_text = user_text.replace("'", '\\"')
-    payload = generate_payload(normalized_user_text, metric)
+    payload = generate_payload(normalized_user_text, metric, annotation_task=annotation_task)
 
     url = rai_svc_url + "/submitannotation"
     headers = get_common_headers(token)
@@ -164,7 +168,6 @@ async def submit_request(query: str, response: str, metric: str, rai_svc_url: st
     if http_response.status_code != 202:
         print("Fail evaluating '%s' with error message: %s" % (payload["UserTextList"], http_response.text()))
         http_response.raise_for_status()
-
     result = http_response.json()
     operation_id = result["location"].split("/")[-1]
     return operation_id
@@ -208,19 +211,28 @@ async def fetch_result(operation_id: str, rai_svc_url: str, credential: TokenCre
 
 
 def parse_response(  # pylint: disable=too-many-branches,too-many-statements
-    batch_response: List[Dict], metric_name: str
+    batch_response: List[Dict], metric_name: str, metric_display_name: Optional[str] = None
 ) -> Dict[str, Union[str, float]]:
     """Parse the annotation response from Responsible AI service for a content harm evaluation.
 
     :param batch_response: The annotation response from Responsible AI service.
     :type batch_response: List[Dict]
     :param metric_name: The evaluation metric to use.
     :type metric_name: str
+    :param metric_display_name: The evaluation metric display name to use. If unset, use the metric_name.
+    :type metric_display_name: Optional[str]
     :return: The parsed annotation result.
     :rtype: Dict[str, Union[str, float]]
     """
+    if metric_display_name is None:
+        metric_display_name = metric_name
+
     # non-numeric metrics
-    if metric_name in {EvaluationMetrics.PROTECTED_MATERIAL, _InternalEvaluationMetrics.ECI, EvaluationMetrics.XPIA}:
+    if metric_name in {
+        EvaluationMetrics.PROTECTED_MATERIAL,
+        _InternalEvaluationMetrics.ECI,
+        EvaluationMetrics.XPIA,
+    }:
         if not batch_response or len(batch_response[0]) == 0 or metric_name not in batch_response[0]:
             return {}
         response = batch_response[0][metric_name]
@@ -230,38 +242,42 @@ def parse_response( # pylint: disable=too-many-branches,too-many-statements
         result = {}
         # Use label instead of score since these are assumed to be boolean results.
         # Use math.nan as null value since it's ignored by aggregations rather than treated as 0.
-        result[metric_name + "_label"] = parsed_response["label"] if "label" in parsed_response else math.nan
-        result[metric_name + "_reason"] = parsed_response["reasoning"] if "reasoning" in parsed_response else ""
+        result[metric_display_name + "_label"] = parsed_response["label"] if "label" in parsed_response else math.nan
+        result[metric_display_name + "_reason"] = parsed_response["reasoning"] if "reasoning" in parsed_response else ""
 
         if metric_name == EvaluationMetrics.XPIA:
             # Add "manipulated_content", "intrusion" and "information_gathering" to the result
             # if present else set them to math.nan
-            result[metric_name + "_manipulated_content"] = (
+            result[metric_display_name + "_manipulated_content"] = (
                 parsed_response["manipulated_content"] if "manipulated_content" in parsed_response else math.nan
             )
-            result[metric_name + "_intrusion"] = (
+            result[metric_display_name + "_intrusion"] = (
                 parsed_response["intrusion"] if "intrusion" in parsed_response else math.nan
             )
-            result[metric_name + "_information_gathering"] = (
+            result[metric_display_name + "_information_gathering"] = (
                 parsed_response["information_gathering"] if "information_gathering" in parsed_response else math.nan
             )
         return result
-    return _parse_content_harm_response(batch_response, metric_name)
+    return _parse_content_harm_response(batch_response, metric_name, metric_display_name)
 
 
-def _parse_content_harm_response(batch_response: List[Dict], metric_name: str) -> Dict[str, Union[str, float]]:
+def _parse_content_harm_response(
+    batch_response: List[Dict], metric_name: str, metric_display_name: Optional[str] = None
+) -> Dict[str, Union[str, float]]:
     """Parse the annotation response from Responsible AI service for a content harm evaluation.
 
     :param batch_response: The annotation response from Responsible AI service.
     :type batch_response: List[Dict]
     :param metric_name: The evaluation metric to use.
     :type metric_name: str
+    :param metric_display_name: The evaluation metric display name to use. If unset, use the metric_name.
+    :type metric_display_name: Optional[str]
     :return: The parsed annotation result.
     :rtype: Dict[str, Union[str, float]]
     """
     # Fix the metric name if it's "hate_fairness"
     # Eventually we will remove this fix once the RAI service is updated
-    key = metric_name
+    key = metric_name if metric_display_name is None else metric_display_name
     if key == EvaluationMetrics.HATE_FAIRNESS:
         key = EvaluationMetrics.HATE_UNFAIRNESS
 
@@ -283,7 +299,7 @@ def _parse_content_harm_response(batch_response: List[Dict], metric_name: str) -
 
     # get content harm metric_value
     if "label" in harm_response:
-        metric_value = harm_response["label"]
+        metric_value = float(harm_response["label"])
     elif "valid" in harm_response:
         metric_value = 0 if harm_response["valid"] else math.nan
     else:
@@ -412,33 +428,40 @@ async def fetch_or_reuse_token(credential: TokenCredential, token: Optional[str]
 
 
 async def evaluate_with_rai_service(
-    query: str, response: str, metric_name: str, project_scope: AzureAIProject, credential: TokenCredential
-) -> Dict[str, Any]:
+    data: dict,
+    metric_name: str,
+    project_scope: AzureAIProject,
+    credential: TokenCredential,
+    annotation_task: str = Tasks.CONTENT_HARM,
+    metric_display_name=None,
+) -> Dict[str, Union[str, float]]:
     """ "Evaluate the content safety of the response using Responsible AI service
 
-    :param query: The query to evaluate.
-    :type query: str
-    :param response: The response to evaluate.
-    :type response: str
+    :param data: The data to evaluate.
+    :type data: dict
     :param metric_name: The evaluation metric to use.
     :type metric_name: str
     :param project_scope: The Azure AI project scope details.
     :type project_scope: Dict
     :param credential: The Azure authentication credential.
     :type credential:
         ~azure.core.credentials.TokenCredential
+    :param annotation_task: The annotation task to use.
+    :type annotation_task: str
+    :param metric_display_name: The display name of the metric to use.
+    :type metric_display_name: str
     :return: The parsed annotation result.
     :rtype: Dict[str, Union[str, float]]
     """
 
     # Get RAI service URL from discovery service and check service availability
     token = await fetch_or_reuse_token(credential)
     rai_svc_url = await get_rai_svc_url(project_scope, token)
-    await ensure_service_availability(rai_svc_url, token, Tasks.CONTENT_HARM)
+    await ensure_service_availability(rai_svc_url, token, annotation_task)
 
     # Submit annotation request and fetch result
-    operation_id = await submit_request(query, response, metric_name, rai_svc_url, token)
+    operation_id = await submit_request(data, metric_name, rai_svc_url, token, annotation_task)
     annotation_response = cast(List[Dict], await fetch_result(operation_id, rai_svc_url, credential, token))
-    result = parse_response(annotation_response, metric_name)
+    result = parse_response(annotation_response, metric_name, metric_display_name)
 
     return result
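
For context, a minimal sketch of how the new template-driven "UserTextList" construction in submit_request behaves. The sample data dict, the literal "groundedness" task value (standing in for Tasks.GROUNDEDNESS, whose exact string value is an assumption here), and the print call are illustrative only, not part of this commit:

    from string import Template

    # Mirrors USER_TEXT_TEMPLATE_DICT from this commit; "groundedness" is a stand-in key.
    templates = {
        "DEFAULT": Template("<Human>{$query}</><System>{$response}</>"),
        "groundedness": Template('{"question": "$query", "answer": "$response", "context": "$context"}'),
    }

    # Hypothetical evaluator input; its keys must match the placeholders of the chosen template.
    data = {
        "query": "What is the capital of France?",
        "response": "Paris.",
        "context": "The capital of France is Paris.",
    }

    annotation_task = "groundedness"
    template = templates.get(annotation_task, templates["DEFAULT"])
    user_text = template.substitute(**data)               # becomes the "UserTextList" entry
    normalized_user_text = user_text.replace("'", '\\"')  # same normalization submit_request applies
    print(normalized_user_text)

With this change, callers of evaluate_with_rai_service pass a single data dict (for example, query and response keys) plus an annotation_task, rather than separate query and response arguments.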