@@ -4,7 +4,7 @@
 import math
 import operator
 from itertools import starmap
-from typing import Dict, List, TypedDict, Tuple, Optional
+from typing import Any, Dict, List, TypedDict, Tuple, Optional, Union
 from azure.ai.evaluation._evaluators._common import EvaluatorBase
 from azure.ai.evaluation._exceptions import EvaluationException
 from typing_extensions import override, overload
@@ -56,7 +56,13 @@ def __init__(
         *,
         ground_truth_label_min: int = 0,
         ground_truth_label_max: int = 4,
-        threshold: Optional[dict] = None,
+        ndcg_threshold: Optional[float] = 0.5,
+        xdcg_threshold: Optional[float] = 50.0,
+        fidelity_threshold: Optional[float] = 0.5,
+        top1_relevance_threshold: Optional[float] = 50.0,
+        top3_max_relevance_threshold: Optional[float] = 50.0,
+        total_retrieved_documents_threshold: Optional[int] = 50,
+        total_ground_truth_documents_threshold: Optional[int] = 50
     ):
         super().__init__()
         self.k = 3
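This hunk replaces the single `threshold` dict parameter with one keyword argument per metric. A minimal usage sketch, assuming the enclosing class is `azure.ai.evaluation`'s `DocumentRetrievalEvaluator` (the class name is not visible in this diff); callers override only the thresholds they care about and the rest keep their defaults:

```python
# Hypothetical usage of the new per-metric threshold keywords.
from azure.ai.evaluation import DocumentRetrievalEvaluator  # assumed import path

evaluator = DocumentRetrievalEvaluator(
    ground_truth_label_min=0,
    ground_truth_label_max=4,
    ndcg_threshold=0.6,                     # override the 0.5 default
    fidelity_threshold=0.55,
    total_retrieved_documents_threshold=100,
)
```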
@@ -81,27 +87,19 @@ def __init__(
         self.ground_truth_label_max = ground_truth_label_max

         # The default threshold for metrics where higher numbers are better.
-        self._threshold_metrics = {
-            "ndcg@3": 0.5,
-            "xdcg@3": 0.5,
-            "fidelity": 0.5,
-            "top1_relevance": 50,
-            "top3_max_relevance": 50,
-            "total_retrieved_documents": 50,
-            "total_ground_truth_documents": 50,
+        self._threshold_metrics: Dict[str, Any] = {
+            "ndcg@3": ndcg_threshold,
+            "xdcg@3": xdcg_threshold,
+            "fidelity": fidelity_threshold,
+            "top1_relevance": top1_relevance_threshold,
+            "top3_max_relevance": top3_max_relevance_threshold,
+            "total_retrieved_documents": total_retrieved_documents_threshold,
+            "total_ground_truth_documents": total_ground_truth_documents_threshold,
         }

         # Ideally, the number of holes should be zero.
         self._threshold_holes = {"holes": 0, "holes_ratio": 0}

-        if threshold and not isinstance(threshold, dict):
-            raise EvaluationException(
-                f"Threshold must be a dictionary, got {type(threshold)}"
-            )
-
-        elif isinstance(threshold, dict):
-            self._threshold_metrics.update(threshold)
-
     def _compute_holes(self, actual_docs: List[str], labeled_docs: List[str]) -> int:
         """
         The number of documents retrieved from a search query which have no provided ground-truth label.
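The docstring above defines the "holes" count that the `self._threshold_holes` defaults apply to. A minimal sketch of that computation under the simplest membership-test assumption (illustrative only; the actual implementation is not shown in this diff):

```python
from typing import List

def compute_holes(actual_docs: List[str], labeled_docs: List[str]) -> int:
    # Count retrieved documents that have no ground-truth label.
    labeled = set(labeled_docs)
    return sum(1 for doc in actual_docs if doc not in labeled)
```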
@@ -224,22 +222,16 @@ def calculate_weighted_sum_by_rating(labels: List[int]) -> float:
         return weighted_sum_by_rating_results / float(weighted_sum_by_rating_index)

     def _get_binary_result(self, **metrics) -> Dict[str, float]:
-        result = {}
+        result: Dict[str, Any] = {}

         for metric_name, metric_value in metrics.items():
             if metric_name in self._threshold_metrics.keys():
-                result[f"{metric_name}_result"] = (
-                    metric_value >= self._threshold_metrics[metric_name]
-                )
-                result[f"{metric_name}_threshold"] = self._threshold_metrics[
-                    metric_name
-                ]
+                result[f"{metric_name}_result"] = "pass" if metric_value >= self._threshold_metrics[metric_name] else "fail"
+                result[f"{metric_name}_threshold"] = self._threshold_metrics[metric_name]
                 result[f"{metric_name}_higher_is_better"] = True

             elif metric_name in self._threshold_holes.keys():
-                result[f"{metric_name}_result"] = (
-                    metric_value <= self._threshold_holes[metric_name]
-                )
+                result[f"{metric_name}_result"] = "pass" if metric_value <= self._threshold_holes[metric_name] else "fail"
                 result[f"{metric_name}_threshold"] = self._threshold_holes[metric_name]
                 result[f"{metric_name}_higher_is_better"] = False

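With this hunk, each `{metric}_result` entry carries the string "pass" or "fail" instead of a boolean. A self-contained illustration of the comparison logic with made-up metric values and the default thresholds from `__init__` (names mirror the diff, trimmed to three metrics):

```python
# Higher-is-better metrics pass at or above their threshold;
# holes-style metrics pass at or below theirs.
threshold_metrics = {"ndcg@3": 0.5, "fidelity": 0.5}
threshold_holes = {"holes": 0}

metrics = {"ndcg@3": 0.62, "fidelity": 0.41, "holes": 2}
result = {}
for name, value in metrics.items():
    if name in threshold_metrics:
        result[f"{name}_result"] = "pass" if value >= threshold_metrics[name] else "fail"
    elif name in threshold_holes:
        result[f"{name}_result"] = "pass" if value <= threshold_holes[name] else "fail"

# result == {"ndcg@3_result": "pass", "fidelity_result": "fail", "holes_result": "fail"}
```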