11"""Retrieval metrics."""
22
3+ from functools import wraps
34from typing import Any , Protocol
45
56import numpy as np
@@ -109,6 +110,21 @@ def _average_precision(query_label: int, candidate_labels: npt.NDArray[np.int64]
109110 return sum_precision / num_relevant if num_relevant > 0 else 0.0
110111
111112
113+ def ignore_oos (func : RetrievalMetricFn ) -> RetrievalMetricFn :
114+ """Ignore OOS in metrics calculation (decorator)."""
115+
116+ @wraps (func )
117+ def wrapper (query_labels : list [Any | None ], candidates_labels : list [Any ]) -> float :
118+ query_labels_filtered = [lab for lab in query_labels if lab is not None ]
119+ candidates_labels_filtered = [
120+ cand for cand , lab in zip (candidates_labels , query_labels , strict = True ) if lab is not None
121+ ]
122+ return func (query_labels_filtered , candidates_labels_filtered )
123+
124+ return wrapper
125+
126+
127+ @ignore_oos
112128def retrieval_map (query_labels : LABELS_VALUE_TYPE , candidates_labels : CANDIDATE_TYPE , k : int | None = None ) -> float :
113129 r"""
114130 Calculate the mean average precision at position k.
@@ -180,6 +196,7 @@ class of the query :math:`q`,
180196 return sum_precision / num_relevant if num_relevant > 0 else 0.0
181197
182198
199+ @ignore_oos
183200def retrieval_map_intersecting (
184201 query_labels : LABELS_VALUE_TYPE ,
185202 candidates_labels : CANDIDATE_TYPE ,
@@ -215,6 +232,7 @@ def retrieval_map_intersecting(
215232 return sum (ap_list ) / len (ap_list )
216233
217234
235+ @ignore_oos
218236def retrieval_map_macro (
219237 query_labels : LABELS_VALUE_TYPE ,
220238 candidates_labels : CANDIDATE_TYPE ,
@@ -235,47 +253,7 @@ def retrieval_map_macro(
235253 return _macrofy (retrieval_map , query_labels , candidates_labels , k )
236254
237255
def _retrieval_map_numpy(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_TYPE, k: int) -> float:
    r"""
    Calculate mean average precision at position k.

    The mean average precision (MAP) at position :math:`k` is calculated as follows:

    .. math::

        \text{AP}_q = \frac{1}{|R_q|} \sum_{i=1}^{k} P_q(i) \cdot \mathbb{1}(y_{\text{true},q} = y_{\text{pred},i})

        \text{MAP}@k = \frac{1}{|Q|} \sum_{q=1}^{Q} \text{AP}_q

    where:
    - :math:`\text{AP}_q` is the average precision for query :math:`q`,
    - :math:`P_q(i)` is the precision at the :math:`i`-th position for query :math:`q`,
    - :math:`\mathbb{1}(y_{\text{true},q} = y_{\text{pred},i})` is the indicator function that equals
      1 if the true label of the query matches the predicted label at position :math:`i` and 0 otherwise,
    - :math:`|R_q|` is the total number of relevant items for query :math:`q`,
    - :math:`|Q|` is the total number of queries.

    :param query_labels: For each query, this list contains its class labels
    :param candidates_labels: For each query, these lists contain class labels of items ranked by a retrieval model (from most to least relevant)
    :param k: Number of top items to consider for each query
    :return: Score of the retrieval metric
    """  # noqa: E501
    # NOTE(review): transform is defined elsewhere; the indexing below assumes it returns
    # a 1-D array of query labels and a 2-D (queries x candidates) array - confirm.
    query_label_, candidates_labels_ = transform(query_labels, candidates_labels)
    candidates_labels_ = candidates_labels_[:, :k]
    relevance_mask = candidates_labels_ == query_label_[:, None]
    cumulative_relevant = np.cumsum(relevance_mask, axis=1)
    # Size the ranks from the truncated array: the slice has fewer than k columns
    # when fewer candidates are available, and arange(1, k + 1) would not broadcast.
    ranks = np.arange(1, candidates_labels_.shape[1] + 1)
    precision_at_k = cumulative_relevant * relevance_mask / ranks
    sum_precision = np.sum(precision_at_k, axis=1)
    num_relevant = np.sum(relevance_mask, axis=1)
    # Queries without any relevant candidate keep an average precision of 0.
    average_precision = np.divide(
        sum_precision,
        num_relevant,
        out=np.zeros_like(sum_precision),
        where=num_relevant != 0,
    )
    return float(np.mean(average_precision))
277-
278-
256+ @ignore_oos
279257def retrieval_hit_rate (
280258 query_labels : LABELS_VALUE_TYPE ,
281259 candidates_labels : CANDIDATE_TYPE ,
@@ -315,6 +293,7 @@ def retrieval_hit_rate(
315293 return float (hit_count / num_queries )
316294
317295
296+ @ignore_oos
318297def retrieval_hit_rate_intersecting (
319298 query_labels : LABELS_VALUE_TYPE ,
320299 candidates_labels : CANDIDATE_TYPE ,
@@ -360,6 +339,7 @@ def retrieval_hit_rate_intersecting(
360339 return float (hit_count / num_queries )
361340
362341
342+ @ignore_oos
363343def retrieval_hit_rate_macro (
364344 query_labels : LABELS_VALUE_TYPE ,
365345 candidates_labels : CANDIDATE_TYPE ,
@@ -380,34 +360,7 @@ def retrieval_hit_rate_macro(
380360 return _macrofy (retrieval_hit_rate , query_labels , candidates_labels , k )
381361
382362
def _retrieval_hit_rate_numpy(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_TYPE, k: int) -> float:
    r"""
    Calculate the hit rate at position k.

    The hit rate is calculated as:

    .. math::

        \text{Hit Rate} = \frac{\sum_{i=1}^N \mathbb{1}(y_{\text{query},i} \in y_{\text{candidates},i}^{(1:k)})}{N}

    where:
    - :math:`N` is the total number of queries,
    - :math:`y_{\text{query},i}` is the true label for the :math:`i`-th query,
    - :math:`y_{\text{candidates},i}^{(1:k)}` is the set of top-k predicted labels for the :math:`i`-th query,
    - :math:`\mathbb{1}(\text{condition})` is the indicator function that equals 1 if the condition
      is true and 0 otherwise.

    :param query_labels: For each query, this list contains its class labels
    :param candidates_labels: For each query, these lists contain class labels of items ranked by a retrieval model (from most to least relevant)
    :param k: Number of top items to consider for each query
    :return: Score of the retrieval metric
    """  # noqa: E501
    # NOTE(review): transform is defined elsewhere; the indexing below assumes it returns
    # a 1-D array of query labels and a 2-D (queries x candidates) array - confirm.
    query_label_, candidates_labels_ = transform(query_labels, candidates_labels)
    truncated_candidates = candidates_labels_[:, :k]
    # Compare row by row: np.isin flattens its second argument, so it would count a hit
    # whenever the query label appears among the candidates of *any* query.
    hit_mask = (truncated_candidates == query_label_[:, None]).any(axis=1)
    return float(hit_mask.mean())
409-
410-
363+ @ignore_oos
411364def retrieval_precision (
412365 query_labels : LABELS_VALUE_TYPE ,
413366 candidates_labels : CANDIDATE_TYPE ,
@@ -449,6 +402,7 @@ def retrieval_precision(
449402 return float (total_precision / num_queries )
450403
451404
405+ @ignore_oos
452406def retrieval_precision_intersecting (
453407 query_labels : LABELS_VALUE_TYPE ,
454408 candidates_labels : CANDIDATE_TYPE ,
@@ -496,6 +450,7 @@ def retrieval_precision_intersecting(
496450 return float (total_precision / num_queries )
497451
498452
453+ @ignore_oos
499454def retrieval_precision_macro (
500455 query_labels : LABELS_VALUE_TYPE ,
501456 candidates_labels : CANDIDATE_TYPE ,
@@ -516,41 +471,6 @@ def retrieval_precision_macro(
516471 return _macrofy (retrieval_precision , query_labels , candidates_labels , k )
517472
518473
def _retrieval_precision_numpy(
    query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_TYPE, k: int | None = None
) -> float:
    r"""
    Calculate the precision at position k.

    Precision at position :math:`k` is calculated as:

    .. math::

        \text{Precision@k} = \frac{1}{N} \sum_{i=1}^N \frac{\sum_{j=1}^k
        \mathbb{1}(y_{\text{query},i} = y_{\text{candidates},i,j})}{k}

    where:
    - :math:`N` is the total number of queries,
    - :math:`y_{\text{query},i}` is the true label for the :math:`i`-th query,
    - :math:`y_{\text{candidates},i,j}` is the :math:`j`-th predicted label for the :math:`i`-th query,
    - :math:`\mathbb{1}(\text{condition})` is the indicator function that equals 1 if the
      condition is true and 0 otherwise,
    - :math:`k` is the number of top candidates considered.

    :param query_labels: For each query, this list contains its class labels
    :param candidates_labels: For each query, these lists contain class labels of items ranked by a retrieval model
        (from most to least relevant)
    :param k: Number of top items to consider for each query; ``None`` uses every candidate
    :return: Score of the retrieval metric
    """
    # NOTE(review): transform is defined elsewhere; the indexing below assumes it returns
    # a 1-D array of query labels and a 2-D (queries x candidates) array - confirm.
    query_label_, candidates_labels_ = transform(query_labels, candidates_labels)
    top_k_candidates = candidates_labels_[:, :k]
    # k=None keeps every candidate column, so divide by that column count
    # instead of failing on `relevant_counts / None`.
    denominator = top_k_candidates.shape[1] if k is None else k
    matches = top_k_candidates == query_label_[:, None]
    relevant_counts = np.sum(matches, axis=1)
    precision_at_k = relevant_counts / denominator
    return float(np.mean(precision_at_k))
552-
553-
554474def _dcg (relevance_scores : npt .NDArray [Any ], k : int | None = None ) -> float :
555475 r"""
556476 Calculate the Discounted Cumulative Gain (DCG) at position k.
@@ -597,6 +517,7 @@ def _idcg(relevance_scores: npt.NDArray[Any], k: int | None = None) -> float:
597517 return _dcg (ideal_scores , k )
598518
599519
520+ @ignore_oos
600521def retrieval_ndcg (query_labels : LABELS_VALUE_TYPE , candidates_labels : CANDIDATE_TYPE , k : int | None = None ) -> float :
601522 r"""
602523 Calculate the Normalized Discounted Cumulative Gain (NDCG) at position k.
@@ -632,6 +553,7 @@ def retrieval_ndcg(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE
632553 return float (np .mean (ndcg_scores ))
633554
634555
556+ @ignore_oos
635557def retrieval_ndcg_intersecting (
636558 query_labels : LABELS_VALUE_TYPE ,
637559 candidates_labels : CANDIDATE_TYPE ,
@@ -674,6 +596,7 @@ def retrieval_ndcg_intersecting(
674596 return np .mean (ndcg_scores ) # type: ignore[return-value]
675597
676598
599+ @ignore_oos
677600def retrieval_ndcg_macro (
678601 query_labels : LABELS_VALUE_TYPE ,
679602 candidates_labels : CANDIDATE_TYPE ,
@@ -692,6 +615,7 @@ def retrieval_ndcg_macro(
692615 return _macrofy (retrieval_ndcg , query_labels , candidates_labels , k )
693616
694617
618+ @ignore_oos
695619def retrieval_mrr (query_labels : LABELS_VALUE_TYPE , candidates_labels : CANDIDATE_TYPE , k : int | None = None ) -> float :
696620 r"""
697621 Calculate the Mean Reciprocal Rank (MRR) at position k.
@@ -726,6 +650,7 @@ def retrieval_mrr(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_
726650 return float (mrr_sum / num_queries )
727651
728652
653+ @ignore_oos
729654def retrieval_mrr_intersecting (
730655 query_labels : LABELS_VALUE_TYPE ,
731656 candidates_labels : CANDIDATE_TYPE ,
@@ -766,6 +691,7 @@ def retrieval_mrr_intersecting(
766691 return float (mrr_sum / num_queries )
767692
768693
694+ @ignore_oos
769695def retrieval_mrr_macro (
770696 query_labels : LABELS_VALUE_TYPE ,
771697 candidates_labels : CANDIDATE_TYPE ,
0 commit comments