44import math
55import operator
66from itertools import starmap
7- from typing import Dict , List , TypedDict , Tuple , Optional
7+ from typing import Any , Dict , List , TypedDict , Tuple , Optional , Union
88from azure .ai .evaluation ._evaluators ._common import EvaluatorBase
99from azure .ai .evaluation ._exceptions import EvaluationException
1010from typing_extensions import override , overload
@@ -56,7 +56,13 @@ def __init__(
5656 * ,
5757 ground_truth_label_min : int = 0 ,
5858 ground_truth_label_max : int = 4 ,
59- threshold : Optional [dict ] = None ,
59+ ndcg_threshold : Optional [float ] = 0.5 ,
60+ xdcg_threshold : Optional [float ] = 50.0 ,
61+ fidelity_threshold : Optional [float ] = 0.5 ,
62+ top1_relevance_threshold : Optional [float ] = 50.0 ,
63+ top3_max_relevance_threshold : Optional [float ] = 50.0 ,
64+ total_retrieved_documents_threshold : Optional [int ] = 50 ,
65+ total_ground_truth_documents_threshold : Optional [int ] = 50
6066 ):
6167 super ().__init__ ()
6268 self .k = 3
@@ -81,27 +87,19 @@ def __init__(
8187 self .ground_truth_label_max = ground_truth_label_max
8288
8389 # The default threshold for metrics where higher numbers are better.
84- self ._threshold_metrics = {
85- "ndcg@3" : 0.5 ,
86- "xdcg@3" : 0.5 ,
87- "fidelity" : 0.5 ,
88- "top1_relevance" : 50 ,
89- "top3_max_relevance" : 50 ,
90- "total_retrieved_documents" : 50 ,
91- "total_ground_truth_documents" : 50 ,
90+ self ._threshold_metrics : Dict [ str , Any ] = {
91+ "ndcg@3" : ndcg_threshold ,
92+ "xdcg@3" : xdcg_threshold ,
93+ "fidelity" : fidelity_threshold ,
94+ "top1_relevance" : top1_relevance_threshold ,
95+ "top3_max_relevance" : top3_max_relevance_threshold ,
96+ "total_retrieved_documents" : total_retrieved_documents_threshold ,
97+ "total_ground_truth_documents" : total_ground_truth_documents_threshold ,
9298 }
9399
94100 # Ideally, the number of holes should be zero.
95101 self ._threshold_holes = {"holes" : 0 , "holes_ratio" : 0 }
96102
97- if threshold and not isinstance (threshold , dict ):
98- raise EvaluationException (
99- f"Threshold must be a dictionary, got { type (threshold )} "
100- )
101-
102- elif isinstance (threshold , dict ):
103- self ._threshold_metrics .update (threshold )
104-
105103 def _compute_holes (self , actual_docs : List [str ], labeled_docs : List [str ]) -> int :
106104 """
107105 The number of documents retrieved from a search query which have no provided ground-truth label.
@@ -224,22 +222,16 @@ def calculate_weighted_sum_by_rating(labels: List[int]) -> float:
224222 return weighted_sum_by_rating_results / float (weighted_sum_by_rating_index )
225223
226224 def _get_binary_result (self , ** metrics ) -> Dict [str , float ]:
227- result = {}
225+ result : Dict [ str , Any ] = {}
228226
229227 for metric_name , metric_value in metrics .items ():
230228 if metric_name in self ._threshold_metrics .keys ():
231- result [f"{ metric_name } _result" ] = (
232- metric_value >= self ._threshold_metrics [metric_name ]
233- )
234- result [f"{ metric_name } _threshold" ] = self ._threshold_metrics [
235- metric_name
236- ]
229+ result [f"{ metric_name } _result" ] = "pass" if metric_value >= self ._threshold_metrics [metric_name ] else "fail"
230+ result [f"{ metric_name } _threshold" ] = self ._threshold_metrics [metric_name ]
237231 result [f"{ metric_name } _higher_is_better" ] = True
238232
239233 elif metric_name in self ._threshold_holes .keys ():
240- result [f"{ metric_name } _result" ] = (
241- metric_value <= self ._threshold_holes [metric_name ]
242- )
234+ result [f"{ metric_name } _result" ] = "pass" if metric_value <= self ._threshold_holes [metric_name ] else "fail"
243235 result [f"{ metric_name } _threshold" ] = self ._threshold_holes [metric_name ]
244236 result [f"{ metric_name } _higher_is_better" ] = False
245237
0 commit comments