
Commit 7420394

Authored by abhahn (Abby Hartman) and co-author
DocumentRetrievalEvaluator: Small fixes for importing, threshold setting and metrics output (Azure#40929)
* Small fixes for importing, threshold setting and metrics output
* Updated threshold test
* Update eval mapping

---------

Co-authored-by: Abby Hartman <[email protected]>
1 parent 3441543 commit 7420394

File tree: 4 files changed (+35, -41 lines)


sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@
 from ._evaluators._code_vulnerability import CodeVulnerabilityEvaluator
 from ._evaluators._ungrounded_attributes import UngroundedAttributesEvaluator
 from ._evaluators._tool_call_accuracy import ToolCallAccuracyEvaluator
+from ._evaluators._document_retrieval import DocumentRetrievalEvaluator
 from ._model_configurations import (
     AzureAIProject,
     AzureOpenAIModelConfiguration,
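
With the evaluator re-exported from the package root, it should be importable without reaching into the private `_evaluators` module. A minimal sketch, assuming the package is installed and that the class name matches the re-export above:

# Top-level import now resolves because __init__.py re-exports the class.
from azure.ai.evaluation import DocumentRetrievalEvaluator

# Construct with defaults; thresholds can be overridden per the constructor
# change further down in this commit.
evaluator = DocumentRetrievalEvaluator()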

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_eval_mapping.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
     CodeVulnerabilityEvaluator,
     CoherenceEvaluator,
     ContentSafetyEvaluator,
+    DocumentRetrievalEvaluator,
     F1ScoreEvaluator,
     FluencyEvaluator,
     GleuScoreEvaluator,
@@ -45,6 +46,7 @@
     CodeVulnerabilityEvaluator: "code_vulnerability",
     CoherenceEvaluator: "coherence",
     ContentSafetyEvaluator: "content_safety",
+    DocumentRetrievalEvaluator: "document_retrieval",
     ECIEvaluator: "eci",
     F1ScoreEvaluator: "f1_score",
     FluencyEvaluator: "fluency",
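
The second hunk registers the evaluator under the metric name "document_retrieval". Illustrative sketch only; the actual dictionary name in `_eval_mapping.py` is not visible in this diff, so `_EVALUATOR_NAME_MAP` below is a hypothetical stand-in for the class-to-name mapping:

# Hypothetical stand-in for the mapping shown in the diff above.
_EVALUATOR_NAME_MAP = {
    DocumentRetrievalEvaluator: "document_retrieval",
    F1ScoreEvaluator: "f1_score",
}

def metric_name_for(evaluator_cls) -> str:
    # Resolve the short metric name an evaluator reports under.
    return _EVALUATOR_NAME_MAP[evaluator_cls]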

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_document_retrieval/_document_retrieval.py

Lines changed: 20 additions & 28 deletions
@@ -4,7 +4,7 @@
 import math
 import operator
 from itertools import starmap
-from typing import Dict, List, TypedDict, Tuple, Optional
+from typing import Any, Dict, List, TypedDict, Tuple, Optional, Union
 from azure.ai.evaluation._evaluators._common import EvaluatorBase
 from azure.ai.evaluation._exceptions import EvaluationException
 from typing_extensions import override, overload
@@ -56,7 +56,13 @@ def __init__(
         *,
         ground_truth_label_min: int = 0,
         ground_truth_label_max: int = 4,
-        threshold: Optional[dict] = None,
+        ndcg_threshold: Optional[float] = 0.5,
+        xdcg_threshold: Optional[float] = 50.0,
+        fidelity_threshold: Optional[float] = 0.5,
+        top1_relevance_threshold: Optional[float] = 50.0,
+        top3_max_relevance_threshold: Optional[float] = 50.0,
+        total_retrieved_documents_threshold: Optional[int] = 50,
+        total_ground_truth_documents_threshold: Optional[int] = 50
     ):
         super().__init__()
         self.k = 3
@@ -81,27 +87,19 @@ def __init__(
         self.ground_truth_label_max = ground_truth_label_max
 
         # The default threshold for metrics where higher numbers are better.
-        self._threshold_metrics = {
-            "ndcg@3": 0.5,
-            "xdcg@3": 0.5,
-            "fidelity": 0.5,
-            "top1_relevance": 50,
-            "top3_max_relevance": 50,
-            "total_retrieved_documents": 50,
-            "total_ground_truth_documents": 50,
+        self._threshold_metrics: Dict[str, Any] = {
+            "ndcg@3": ndcg_threshold,
+            "xdcg@3": xdcg_threshold,
+            "fidelity": fidelity_threshold,
+            "top1_relevance": top1_relevance_threshold,
+            "top3_max_relevance": top3_max_relevance_threshold,
+            "total_retrieved_documents": total_retrieved_documents_threshold,
+            "total_ground_truth_documents": total_ground_truth_documents_threshold,
         }
 
         # Ideally, the number of holes should be zero.
         self._threshold_holes = {"holes": 0, "holes_ratio": 0}
 
-        if threshold and not isinstance(threshold, dict):
-            raise EvaluationException(
-                f"Threshold must be a dictionary, got {type(threshold)}"
-            )
-
-        elif isinstance(threshold, dict):
-            self._threshold_metrics.update(threshold)
-
     def _compute_holes(self, actual_docs: List[str], labeled_docs: List[str]) -> int:
         """
         The number of documents retrieved from a search query which have no provided ground-truth label.
@@ -224,22 +222,16 @@ def calculate_weighted_sum_by_rating(labels: List[int]) -> float:
            return weighted_sum_by_rating_results / float(weighted_sum_by_rating_index)
 
     def _get_binary_result(self, **metrics) -> Dict[str, float]:
-        result = {}
+        result: Dict[str, Any] = {}
 
         for metric_name, metric_value in metrics.items():
             if metric_name in self._threshold_metrics.keys():
-                result[f"{metric_name}_result"] = (
-                    metric_value >= self._threshold_metrics[metric_name]
-                )
-                result[f"{metric_name}_threshold"] = self._threshold_metrics[
-                    metric_name
-                ]
+                result[f"{metric_name}_result"] = "pass" if metric_value >= self._threshold_metrics[metric_name] else "fail"
+                result[f"{metric_name}_threshold"] = self._threshold_metrics[metric_name]
                 result[f"{metric_name}_higher_is_better"] = True
 
             elif metric_name in self._threshold_holes.keys():
-                result[f"{metric_name}_result"] = (
-                    metric_value <= self._threshold_holes[metric_name]
-                )
+                result[f"{metric_name}_result"] = "pass" if metric_value <= self._threshold_holes[metric_name] else "fail"
                 result[f"{metric_name}_threshold"] = self._threshold_holes[metric_name]
                 result[f"{metric_name}_higher_is_better"] = False
 
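
The net effect of this file's changes: the single `threshold` dict argument is replaced with one keyword argument per metric, and `_get_binary_result` now reports "pass"/"fail" strings rather than booleans. A construction sketch under those assumptions (values are illustrative; the evaluator's call signature for scoring records is not shown in this diff):

# Per-metric thresholds are now plain keyword arguments.
evaluator = DocumentRetrievalEvaluator(
    ground_truth_label_min=0,
    ground_truth_label_max=2,
    ndcg_threshold=0.7,
    xdcg_threshold=0.7,
    fidelity_threshold=0.7,
)
# For each metric, the result payload now carries string verdicts, e.g.
#   "ndcg@3_result": "pass" or "fail"
#   "ndcg@3_threshold": 0.7
#   "ndcg@3_higher_is_better": True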

sdk/evaluation/azure-ai-evaluation/tests/unittests/test_document_retrieval_evaluator.py

Lines changed: 12 additions & 13 deletions
@@ -116,28 +116,27 @@ def test_incorrect_groundtruth_max():
         exc_info._excinfo[1]
     )
 
-def test_threshold(doc_retrieval_eval_data):
+def test_thresholds(doc_retrieval_eval_data):
     _, records = doc_retrieval_eval_data
     record = records[-1]
     custom_threshold_subset = {
-        "ndcg@3": 0.7,
-        "xdcg@3": 0.7,
-        "fidelity": 0.7,
+        "ndcg_threshold": 0.7,
+        "xdcg_threshold": 0.7,
+        "fidelity_threshold": 0.7,
     }
 
     custom_threshold_superset = {
-        "ndcg@3": 0.7,
-        "xdcg@3": 0.7,
-        "fidelity": 0.7,
-        "top1_relevance": 70,
-        "top3_max_relevance": 70,
-        "total_retrieved_documents": 10,
-        "total_ground_truth_documents": 10,
-        "unknown_metric": 50
+        "ndcg_threshold": 0.7,
+        "xdcg_threshold": 0.7,
+        "fidelity_threshold": 0.7,
+        "top1_relevance_threshold": 70,
+        "top3_max_relevance_threshold": 70,
+        "total_retrieved_documents_threshold": 10,
+        "total_ground_truth_documents_threshold": 10
    }
 
     for threshold in [custom_threshold_subset, custom_threshold_superset]:
-        evaluator = DocumentRetrievalEvaluator(ground_truth_label_min=0, ground_truth_label_max=2, threshold=threshold)
+        evaluator = DocumentRetrievalEvaluator(ground_truth_label_min=0, ground_truth_label_max=2, **threshold)
         results = evaluator(**record)
 
         expected_keys = [
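
Since thresholds are now real keyword parameters, the old `"unknown_metric": 50` entry from the superset case is gone: assuming the constructor does not also accept arbitrary `**kwargs`, an unrecognized threshold name now fails at construction time instead of being silently merged into the metric map, e.g.:

# Unrecognized threshold names should raise immediately (assumption: the
# constructor takes only the explicit keyword parameters shown above).
try:
    DocumentRetrievalEvaluator(unknown_metric_threshold=50)
except TypeError as exc:
    print(f"rejected: {exc}")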
