
Commit 578b16c

MilesHolland, singankit, and needuv authored
Add groundedness pro eval (Azure#38063)
* Adding service based groundedness
* groundedness pro eval
* remove groundedness and fix unit tests
* run black
* change evaluate label
* Update sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py
  Co-authored-by: Neehar Duvvuri <[email protected]>
* Update sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py
  Co-authored-by: Neehar Duvvuri <[email protected]>
* comments and CL
* re record tests
* black and pylint
* comments
* nits
* analysis
* re cast
* more mypy appeasement

Co-authored-by: Ankit Singhal <[email protected]>
Co-authored-by: Neehar Duvvuri <[email protected]>
1 parent 383b5cd commit 578b16c

File tree: 25 files changed (+403, -75 lines)


sdk/evaluation/azure-ai-evaluation/CHANGELOG.md (1 addition, 0 deletions)

@@ -4,6 +4,7 @@
 ## 1.0.0b5 (Unreleased)
 
 ### Features Added
+- Added `GroundednessProEvaluator`, which is a service-based evaluator for determining response groundedness.
 - Groundedness detection in Non Adversarial Simulator via query/context pairs
 ```python
 import importlib.resources as pkg_resources
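
As a rough companion to the changelog entry, here is a minimal usage sketch of the new evaluator. The constructor and call signatures are inferred from this diff rather than from published docs, and the `azure_ai_project` values are placeholders.

```python
# Hypothetical usage of GroundednessProEvaluator (exported in __init__.py below).
# The project details and the exact keyword arguments are assumptions.
from azure.ai.evaluation import GroundednessProEvaluator
from azure.identity import DefaultAzureCredential

azure_ai_project = {
    "subscription_id": "<subscription-id>",
    "resource_group_name": "<resource-group>",
    "project_name": "<project-name>",
}

evaluator = GroundednessProEvaluator(
    azure_ai_project=azure_ai_project,
    credential=DefaultAzureCredential(),
)

result = evaluator(
    query="Which tent is the most waterproof?",
    response="The Alpine Explorer Tent is the most waterproof.",
    context="The Alpine Explorer Tent has a rainfly waterproof rating of 3000mm.",
)
# Based on the parsing changes below, the result should include keys such as
# groundedness_pro_label and groundedness_pro_reason.
print(result)
```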

sdk/evaluation/azure-ai-evaluation/assets.json (1 addition, 1 deletion)

@@ -2,5 +2,5 @@
   "AssetsRepo": "Azure/azure-sdk-assets",
   "AssetsRepoPrefixPath": "python",
   "TagPrefix": "python/evaluation/azure-ai-evaluation",
-  "Tag": "python/evaluation/azure-ai-evaluation_1390701e9d"
+  "Tag": "python/evaluation/azure-ai-evaluation_5551827d25"
 }

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/__init__.py (2 additions, 0 deletions)

@@ -16,6 +16,7 @@
 from ._evaluators._fluency import FluencyEvaluator
 from ._evaluators._gleu import GleuScoreEvaluator
 from ._evaluators._groundedness import GroundednessEvaluator
+from ._evaluators._service_groundedness import GroundednessProEvaluator
 from ._evaluators._meteor import MeteorScoreEvaluator
 from ._evaluators._protected_material import ProtectedMaterialEvaluator
 from ._evaluators._qa import QAEvaluator
@@ -40,6 +41,7 @@
     "F1ScoreEvaluator",
     "FluencyEvaluator",
     "GroundednessEvaluator",
+    "GroundednessProEvaluator",
     "RelevanceEvaluator",
     "SimilarityEvaluator",
     "QAEvaluator",

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_common/constants.py (2 additions, 0 deletions)

@@ -35,6 +35,7 @@ class Tasks:
     CONTENT_HARM = "content harm"
     PROTECTED_MATERIAL = "protected material"
     XPIA = "xpia"
+    GROUNDEDNESS = "groundedness"
 
 
 class _InternalAnnotationTasks:
@@ -56,6 +57,7 @@ class EvaluationMetrics(str, Enum, metaclass=CaseInsensitiveEnumMeta):
     SEXUAL = "sexual"
     PROTECTED_MATERIAL = "protected_material"
     XPIA = "xpia"
+    GROUNDEDNESS = "generic_groundedness"
 
 
 class _InternalEvaluationMetrics(str, Enum, metaclass=CaseInsensitiveEnumMeta):

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_common/rai_service.py (58 additions, 35 deletions)

@@ -7,8 +7,9 @@
 import re
 import time
 from ast import literal_eval
-from typing import Any, Dict, List, Optional, Union, cast
+from typing import Dict, List, Optional, Union, cast
 from urllib.parse import urlparse
+from string import Template
 
 import jwt
 
@@ -23,7 +24,6 @@
     EvaluationMetrics,
     RAIService,
     Tasks,
-    _InternalAnnotationTasks,
     _InternalEvaluationMetrics,
 )
 from .utils import get_harm_severity_level
@@ -34,6 +34,11 @@
 version = "unknown"
 USER_AGENT = "{}/{}".format("azure-ai-evaluation", version)
 
+USER_TEXT_TEMPLATE_DICT: Dict[str, Template] = {
+    "DEFAULT": Template("<Human>{$query}</><System>{$response}</>"),
+    Tasks.GROUNDEDNESS: Template('{"question": "$query", "answer": "$response", "context": "$context"}'),
+}
+
 
 def get_common_headers(token: str) -> Dict:
     """Get common headers for the HTTP request
@@ -99,27 +104,26 @@ async def ensure_service_availability(rai_svc_url: str, token: str, capability:
     )
 
 
-def generate_payload(normalized_user_text: str, metric: str) -> Dict:
+def generate_payload(normalized_user_text: str, metric: str, annotation_task: str) -> Dict:
     """Generate the payload for the annotation request
 
     :param normalized_user_text: The normalized user text to be entered as the "UserTextList" in the payload.
     :type normalized_user_text: str
     :param metric: The evaluation metric to use. This determines the task type, and whether a "MetricList" is needed
         in the payload.
     :type metric: str
+    :param annotation_task: The annotation task to be passed to service
+    :type annotation_task: str
     :return: The payload for the annotation request.
     :rtype: Dict
     """
     include_metric = True
-    task = Tasks.CONTENT_HARM
+    task = annotation_task
     if metric == EvaluationMetrics.PROTECTED_MATERIAL:
-        task = Tasks.PROTECTED_MATERIAL
         include_metric = False
     elif metric == _InternalEvaluationMetrics.ECI:
-        task = _InternalAnnotationTasks.ECI
         include_metric = False
     elif metric == EvaluationMetrics.XPIA:
-        task = Tasks.XPIA
         include_metric = False
     return (
         {
@@ -135,25 +139,25 @@ def generate_payload(normalized_user_text: str, metric: str) -> Dict:
     )
 
 
-async def submit_request(query: str, response: str, metric: str, rai_svc_url: str, token: str) -> str:
+async def submit_request(data: dict, metric: str, rai_svc_url: str, token: str, annotation_task: str) -> str:
     """Submit request to Responsible AI service for evaluation and return operation ID
 
-    :param query: The query to evaluate.
-    :type query: str
-    :param response: The response to evaluate.
-    :type response: str
+    :param data: The data to evaluate.
+    :type data: dict
     :param metric: The evaluation metric to use.
     :type metric: str
     :param rai_svc_url: The Responsible AI service URL.
     :type rai_svc_url: str
     :param token: The Azure authentication token.
     :type token: str
+    :param annotation_task: The annotation task to use.
+    :type annotation_task: str
     :return: The operation ID.
     :rtype: str
     """
-    user_text = f"<Human>{query}</><System>{response}</>"
+    user_text = USER_TEXT_TEMPLATE_DICT.get(annotation_task, USER_TEXT_TEMPLATE_DICT["DEFAULT"]).substitute(**data)
     normalized_user_text = user_text.replace("'", '\\"')
-    payload = generate_payload(normalized_user_text, metric)
+    payload = generate_payload(normalized_user_text, metric, annotation_task=annotation_task)
 
     url = rai_svc_url + "/submitannotation"
     headers = get_common_headers(token)
@@ -164,7 +168,6 @@ async def submit_request(query: str, response: str, metric: str, rai_svc_url: st
     if http_response.status_code != 202:
         print("Fail evaluating '%s' with error message: %s" % (payload["UserTextList"], http_response.text()))
         http_response.raise_for_status()
-
     result = http_response.json()
     operation_id = result["location"].split("/")[-1]
     return operation_id
@@ -208,19 +211,28 @@ async def fetch_result(operation_id: str, rai_svc_url: str, credential: TokenCre
 
 
 def parse_response(  # pylint: disable=too-many-branches,too-many-statements
-    batch_response: List[Dict], metric_name: str
+    batch_response: List[Dict], metric_name: str, metric_display_name: Optional[str] = None
 ) -> Dict[str, Union[str, float]]:
     """Parse the annotation response from Responsible AI service for a content harm evaluation.
 
     :param batch_response: The annotation response from Responsible AI service.
     :type batch_response: List[Dict]
     :param metric_name: The evaluation metric to use.
     :type metric_name: str
+    :param metric_display_name: The evaluation metric display name to use. If unset, use the metric_name.
+    :type metric_display_name: Optional[str]
     :return: The parsed annotation result.
     :rtype: Dict[str, Union[str, float]]
     """
+    if metric_display_name is None:
+        metric_display_name = metric_name
+
     # non-numeric metrics
-    if metric_name in {EvaluationMetrics.PROTECTED_MATERIAL, _InternalEvaluationMetrics.ECI, EvaluationMetrics.XPIA}:
+    if metric_name in {
+        EvaluationMetrics.PROTECTED_MATERIAL,
+        _InternalEvaluationMetrics.ECI,
+        EvaluationMetrics.XPIA,
+    }:
         if not batch_response or len(batch_response[0]) == 0 or metric_name not in batch_response[0]:
             return {}
         response = batch_response[0][metric_name]
@@ -230,38 +242,42 @@ def parse_response(  # pylint: disable=too-many-branches,too-many-statements
         result = {}
         # Use label instead of score since these are assumed to be boolean results.
         # Use math.nan as null value since it's ignored by aggregations rather than treated as 0.
-        result[metric_name + "_label"] = parsed_response["label"] if "label" in parsed_response else math.nan
-        result[metric_name + "_reason"] = parsed_response["reasoning"] if "reasoning" in parsed_response else ""
+        result[metric_display_name + "_label"] = parsed_response["label"] if "label" in parsed_response else math.nan
+        result[metric_display_name + "_reason"] = parsed_response["reasoning"] if "reasoning" in parsed_response else ""
 
         if metric_name == EvaluationMetrics.XPIA:
             # Add "manipulated_content", "intrusion" and "information_gathering" to the result
             # if present else set them to math.nan
-            result[metric_name + "_manipulated_content"] = (
+            result[metric_display_name + "_manipulated_content"] = (
                 parsed_response["manipulated_content"] if "manipulated_content" in parsed_response else math.nan
             )
-            result[metric_name + "_intrusion"] = (
+            result[metric_display_name + "_intrusion"] = (
                 parsed_response["intrusion"] if "intrusion" in parsed_response else math.nan
            )
-            result[metric_name + "_information_gathering"] = (
+            result[metric_display_name + "_information_gathering"] = (
                 parsed_response["information_gathering"] if "information_gathering" in parsed_response else math.nan
             )
         return result
-    return _parse_content_harm_response(batch_response, metric_name)
+    return _parse_content_harm_response(batch_response, metric_name, metric_display_name)
 
 
-def _parse_content_harm_response(batch_response: List[Dict], metric_name: str) -> Dict[str, Union[str, float]]:
+def _parse_content_harm_response(
+    batch_response: List[Dict], metric_name: str, metric_display_name: Optional[str] = None
+) -> Dict[str, Union[str, float]]:
     """Parse the annotation response from Responsible AI service for a content harm evaluation.
 
     :param batch_response: The annotation response from Responsible AI service.
     :type batch_response: List[Dict]
     :param metric_name: The evaluation metric to use.
     :type metric_name: str
+    :param metric_display_name: The evaluation metric display name to use. If unset, use the metric_name.
+    :type metric_display_name: Optional[str]
     :return: The parsed annotation result.
     :rtype: Dict[str, Union[str, float]]
     """
     # Fix the metric name if it's "hate_fairness"
     # Eventually we will remove this fix once the RAI service is updated
-    key = metric_name
+    key = metric_name if metric_display_name is None else metric_display_name
     if key == EvaluationMetrics.HATE_FAIRNESS:
         key = EvaluationMetrics.HATE_UNFAIRNESS
 
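The effect of `metric_display_name` is easiest to see on the returned keys. A toy sketch of just the key construction (the real function also handles `literal_eval`, XPIA sub-metrics, and content-harm parsing):

```python
# Toy sketch of the metric_display_name key handling in parse_response.
import math
from typing import Dict, Optional, Union


def label_keys(
    parsed: Dict, metric_name: str, metric_display_name: Optional[str] = None
) -> Dict[str, Union[str, float]]:
    # Fall back to the metric name when no display name is given, as in the diff.
    name = metric_name if metric_display_name is None else metric_display_name
    return {
        name + "_label": parsed.get("label", math.nan),
        name + "_reason": parsed.get("reasoning", ""),
    }


parsed = {"label": True, "reasoning": "All claims are supported by the context."}
print(label_keys(parsed, "generic_groundedness", metric_display_name="groundedness_pro"))
# {'groundedness_pro_label': True, 'groundedness_pro_reason': 'All claims are supported by the context.'}
```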
@@ -283,7 +299,7 @@ def _parse_content_harm_response(batch_response: List[Dict], metric_name: str) -
 
         # get content harm metric_value
         if "label" in harm_response:
-            metric_value = harm_response["label"]
+            metric_value = float(harm_response["label"])
         elif "valid" in harm_response:
             metric_value = 0 if harm_response["valid"] else math.nan
         else:
@@ -412,33 +428,40 @@ async def fetch_or_reuse_token(credential: TokenCredential, token: Optional[str]
 
 
 async def evaluate_with_rai_service(
-    query: str, response: str, metric_name: str, project_scope: AzureAIProject, credential: TokenCredential
-) -> Dict[str, Any]:
+    data: dict,
+    metric_name: str,
+    project_scope: AzureAIProject,
+    credential: TokenCredential,
+    annotation_task: str = Tasks.CONTENT_HARM,
+    metric_display_name=None,
+) -> Dict[str, Union[str, float]]:
     """ "Evaluate the content safety of the response using Responsible AI service
 
-    :param query: The query to evaluate.
-    :type query: str
-    :param response: The response to evaluate.
-    :type response: str
+    :param data: The data to evaluate.
+    :type data: dict
     :param metric_name: The evaluation metric to use.
     :type metric_name: str
     :param project_scope: The Azure AI project scope details.
     :type project_scope: Dict
     :param credential: The Azure authentication credential.
     :type credential:
        ~azure.core.credentials.TokenCredential
+    :param annotation_task: The annotation task to use.
+    :type annotation_task: str
+    :param metric_display_name: The display name of metric to use.
+    :type metric_display_name: str
     :return: The parsed annotation result.
     :rtype: Dict[str, Union[str, float]]
     """
 
     # Get RAI service URL from discovery service and check service availability
     token = await fetch_or_reuse_token(credential)
     rai_svc_url = await get_rai_svc_url(project_scope, token)
-    await ensure_service_availability(rai_svc_url, token, Tasks.CONTENT_HARM)
+    await ensure_service_availability(rai_svc_url, token, annotation_task)
 
     # Submit annotation request and fetch result
-    operation_id = await submit_request(query, response, metric_name, rai_svc_url, token)
+    operation_id = await submit_request(data, metric_name, rai_svc_url, token, annotation_task)
     annotation_response = cast(List[Dict], await fetch_result(operation_id, rai_svc_url, credential, token))
-    result = parse_response(annotation_response, metric_name)
+    result = parse_response(annotation_response, metric_name, metric_display_name)
 
     return result
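
Putting the pieces together, a caller would invoke the updated entry point roughly as below. This is a sketch based on the signature in this diff; the modules are internal (`_common`), and the project details and credential are placeholders.

```python
# Hypothetical caller of the updated evaluate_with_rai_service.
import asyncio

from azure.identity import DefaultAzureCredential
from azure.ai.evaluation._common.constants import EvaluationMetrics, Tasks
from azure.ai.evaluation._common.rai_service import evaluate_with_rai_service


async def main():
    result = await evaluate_with_rai_service(
        data={"query": "...", "response": "...", "context": "..."},
        metric_name=EvaluationMetrics.GROUNDEDNESS,  # "generic_groundedness"
        project_scope={
            "subscription_id": "<subscription-id>",
            "resource_group_name": "<resource-group>",
            "project_name": "<project-name>",
        },
        credential=DefaultAzureCredential(),
        annotation_task=Tasks.GROUNDEDNESS,
        metric_display_name="groundedness_pro",
    )
    # Expected keys per parse_response: groundedness_pro_label, groundedness_pro_reason.
    print(result)


asyncio.run(main())
```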

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_common/utils.py (1 addition, 1 deletion)

@@ -24,7 +24,7 @@
 T_TypedDict = TypeVar("T_TypedDict", bound=Mapping[Any, Any])
 
 
-def get_harm_severity_level(harm_score: int) -> Union[str, float]:
+def get_harm_severity_level(harm_score: Union[float, int]) -> Union[str, float]:
     """Generate harm severity level based on harm score.
 
     :param harm_score: The harm score to be evaluated.

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_evaluate.py (37 additions, 0 deletions)

@@ -36,13 +36,47 @@
 
 TClient = TypeVar("TClient", ProxyClient, CodeClient)
 
+# For metrics (aggregates) whose metric names intentionally differ from their
+# originating column name, usually because the aggregation of the original value
+# means something sufficiently different.
+# Note that content safety metrics are handled seprately.
+METRIC_COLUMN_NAME_REPLACEMENTS = {
+    "groundedness_pro_label": "groundedness_pro_passing_rate",
+}
+
 
 class __EvaluatorInfo(TypedDict):
     result: pd.DataFrame
     metrics: Dict[str, Any]
     run_summary: Dict[str, Any]
 
 
+def _aggregate_other_metrics(df: pd.DataFrame) -> Tuple[List[str], Dict[str, float]]:
+    """Identify and average various metrics that need to have the metric name be replaced,
+    instead of having the metric match the originating column name.
+    :param df: The dataframe of evaluation results.
+    :type df: ~pandas.DataFrame
+    :return: A tuple; the first element is a list of dataframe columns that were aggregated,
+        and the second element is a dictionary of resultant new metric column names and their values.
+    :rtype: Tuple[List[str], Dict[str, float]]
+    """
+    renamed_cols = []
+    metric_columns = {}
+    for col in df.columns:
+        metric_prefix = col.split(".")[0]
+        metric_name = col.split(".")[1]
+        if metric_name in METRIC_COLUMN_NAME_REPLACEMENTS:
+            renamed_cols.append(col)
+            new_col_name = metric_prefix + "." + METRIC_COLUMN_NAME_REPLACEMENTS[metric_name]
+            col_with_numeric_values = pd.to_numeric(df[col], errors="coerce")
+            metric_columns[new_col_name] = round(
+                list_sum(col_with_numeric_values) / col_with_numeric_values.count(),
+                2,
+            )
+
+    return renamed_cols, metric_columns
+
+
 # pylint: disable=line-too-long
 def _aggregate_content_safety_metrics(
     df: pd.DataFrame, evaluators: Dict[str, Callable]
@@ -146,8 +180,11 @@ def _aggregate_metrics(df: pd.DataFrame, evaluators: Dict[str, Callable]) -> Dic
     # Rename certain columns as defect rates if we know that's what their aggregates represent
     # Content safety metrics
     content_safety_cols, cs_defect_rates = _aggregate_content_safety_metrics(df, evaluators)
+    other_renamed_cols, renamed_cols = _aggregate_other_metrics(df)
     handled_columns.extend(content_safety_cols)
+    handled_columns.extend(other_renamed_cols)
     defect_rates.update(cs_defect_rates)
+    defect_rates.update(renamed_cols)
     # Label-based (true/false) metrics where 'true' means 'something is wrong'
     label_cols, label_defect_rates = _aggregate_label_defect_metrics(df)
     handled_columns.extend(label_cols)
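
To see what the rename does during aggregation, here is a rough stand-in using plain pandas. It assumes the columns reaching `_aggregate_other_metrics` look like `<evaluator_name>.<metric_name>` (i.e. any `outputs.` prefix has already been stripped by the caller) and uses a pandas sum in place of the module's `list_sum` helper.

```python
# Rough stand-in for _aggregate_other_metrics on a toy results dataframe.
import pandas as pd

REPLACEMENTS = {"groundedness_pro_label": "groundedness_pro_passing_rate"}

df = pd.DataFrame({"groundedness_pro.groundedness_pro_label": [True, True, False, True]})

metrics = {}
for col in df.columns:
    evaluator_name, metric_name = col.split(".")[0], col.split(".")[1]
    if metric_name in REPLACEMENTS:
        values = pd.to_numeric(df[col], errors="coerce")
        # Average of the boolean labels, reported under the *_passing_rate name.
        metrics[evaluator_name + "." + REPLACEMENTS[metric_name]] = round(values.sum() / values.count(), 2)

print(metrics)  # {'groundedness_pro.groundedness_pro_passing_rate': 0.75}
```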

0 commit comments
