Skip to content

Commit e24bde4

Browse files
committed
add ignore_oos decorator
1 parent b2c8986 commit e24bde4

File tree

2 files changed

+58
-105
lines changed

2 files changed

+58
-105
lines changed

autointent/metrics/retrieval.py

Lines changed: 30 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Retrieval metrics."""
22

3+
from functools import wraps
34
from typing import Any, Protocol
45

56
import numpy as np
@@ -109,6 +110,21 @@ def _average_precision(query_label: int, candidate_labels: npt.NDArray[np.int64]
109110
return sum_precision / num_relevant if num_relevant > 0 else 0.0
110111

111112

113+
def ignore_oos(func: RetrievalMetricFn) -> RetrievalMetricFn:
    """Ignore out-of-scope (OOS) samples in metrics calculation (decorator).

    OOS queries are marked with a ``None`` label. The wrapped metric is called
    only on in-scope queries: ``None``-labeled entries and their corresponding
    candidate lists are filtered out first, so OOS samples do not distort the
    retrieval score.

    :param func: Retrieval metric function to wrap
    :return: Metric function that skips OOS queries before scoring
    """

    @wraps(func)
    def wrapper(
        query_labels: list[Any | None], candidates_labels: list[Any], *args: Any, **kwargs: Any
    ) -> float:
        # Keep only in-scope queries; a None label marks an OOS sample.
        query_labels_filtered = [lab for lab in query_labels if lab is not None]
        # Drop the candidate lists that belong to the removed OOS queries,
        # keeping both sequences aligned (strict=True guards against length drift).
        candidates_labels_filtered = [
            cand for cand, lab in zip(candidates_labels, query_labels, strict=True) if lab is not None
        ]
        # Forward any extra arguments (e.g. the ``k`` cutoff every decorated
        # metric accepts) untouched. The previous wrapper swallowed them,
        # so calls like ``retrieval_map(labels, candidates, k=5)`` raised TypeError.
        return func(query_labels_filtered, candidates_labels_filtered, *args, **kwargs)

    return wrapper
125+
126+
127+
@ignore_oos
112128
def retrieval_map(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_TYPE, k: int | None = None) -> float:
113129
r"""
114130
Calculate the mean average precision at position k.
@@ -180,6 +196,7 @@ class of the query :math:`q`,
180196
return sum_precision / num_relevant if num_relevant > 0 else 0.0
181197

182198

199+
@ignore_oos
183200
def retrieval_map_intersecting(
184201
query_labels: LABELS_VALUE_TYPE,
185202
candidates_labels: CANDIDATE_TYPE,
@@ -215,6 +232,7 @@ def retrieval_map_intersecting(
215232
return sum(ap_list) / len(ap_list)
216233

217234

235+
@ignore_oos
218236
def retrieval_map_macro(
219237
query_labels: LABELS_VALUE_TYPE,
220238
candidates_labels: CANDIDATE_TYPE,
@@ -235,47 +253,7 @@ def retrieval_map_macro(
235253
return _macrofy(retrieval_map, query_labels, candidates_labels, k)
236254

237255

238-
def _retrieval_map_numpy(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_TYPE, k: int) -> float:
239-
r"""
240-
Calculate mean average precision at position k.
241-
242-
The mean average precision (MAP) at position :math:`k` is calculated as follows:
243-
244-
.. math::
245-
246-
\text{AP}_q = \frac{1}{|R_q|} \sum_{i=1}^{k} P_q(i) \cdot \mathbb{1}(y_{\text{true},q} = y_{\text{pred},i})
247-
248-
\text{MAP}@k = \frac{1}{|Q|} \sum_{q=1}^{Q} \text{AP}_q
249-
250-
where:
251-
- :math:`\text{AP}_q` is the average precision for query :math:`q`,
252-
- :math:`P_q(i)` is the precision at the :math:`i`-th position for query :math:`q`,
253-
- :math:`\mathbb{1}(y_{\text{true},q} = y_{\text{pred},i})` is the indicator function that equals
254-
1 if the true label of the query matches the predicted label at position :math:`i` and 0 otherwise,
255-
- :math:`|R_q|` is the total number of relevant items for query :math:`q`,
256-
- :math:`|Q|` is the total number of queries.
257-
258-
:param query_labels: For each query, this list contains its class labels
259-
:param candidates_labels: For each query, these lists contain class labels of items ranked by a retrieval model (from most to least relevant)
260-
:param k: Number of top items to consider for each query
261-
:return: Score of the retrieval metric
262-
""" # noqa: E501
263-
query_label_, candidates_labels_ = transform(query_labels, candidates_labels)
264-
candidates_labels_ = candidates_labels_[:, :k]
265-
relevance_mask = candidates_labels_ == query_label_[:, None]
266-
cumulative_relevant = np.cumsum(relevance_mask, axis=1)
267-
precision_at_k = cumulative_relevant * relevance_mask / np.arange(1, k + 1)
268-
sum_precision = np.sum(precision_at_k, axis=1)
269-
num_relevant = np.sum(relevance_mask, axis=1)
270-
average_precision = np.divide(
271-
sum_precision,
272-
num_relevant,
273-
out=np.zeros_like(sum_precision),
274-
where=num_relevant != 0,
275-
)
276-
return np.mean(average_precision) # type: ignore[no-any-return]
277-
278-
256+
@ignore_oos
279257
def retrieval_hit_rate(
280258
query_labels: LABELS_VALUE_TYPE,
281259
candidates_labels: CANDIDATE_TYPE,
@@ -315,6 +293,7 @@ def retrieval_hit_rate(
315293
return float(hit_count / num_queries)
316294

317295

296+
@ignore_oos
318297
def retrieval_hit_rate_intersecting(
319298
query_labels: LABELS_VALUE_TYPE,
320299
candidates_labels: CANDIDATE_TYPE,
@@ -360,6 +339,7 @@ def retrieval_hit_rate_intersecting(
360339
return float(hit_count / num_queries)
361340

362341

342+
@ignore_oos
363343
def retrieval_hit_rate_macro(
364344
query_labels: LABELS_VALUE_TYPE,
365345
candidates_labels: CANDIDATE_TYPE,
@@ -380,34 +360,7 @@ def retrieval_hit_rate_macro(
380360
return _macrofy(retrieval_hit_rate, query_labels, candidates_labels, k)
381361

382362

383-
def _retrieval_hit_rate_numpy(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_TYPE, k: int) -> float:
384-
r"""
385-
Calculate the hit rate at position k.
386-
387-
The hit rate is calculated as:
388-
389-
.. math::
390-
391-
\text{Hit Rate} = \frac{\sum_{i=1}^N \mathbb{1}(y_{\text{query},i} \in y_{\text{candidates},i}^{(1:k)})}{N}
392-
393-
where:
394-
- :math:`N` is the total number of queries,
395-
- :math:`y_{\text{query},i}` is the true label for the :math:`i`-th query,
396-
- :math:`y_{\text{candidates},i}^{(1:k)}` is the set of top-k predicted labels for the :math:`i`-th query,
397-
- :math:`\mathbb{1}(\text{condition})` is the indicator function that equals 1 if the condition
398-
is true and 0 otherwise.
399-
400-
:param query_labels: For each query, this list contains its class labels
401-
:param candidates_labels: For each query, these lists contain class labels of items ranked by a retrieval model (from most to least relevant)
402-
:param k: Number of top items to consider for each query
403-
:return: Score of the retrieval metric
404-
""" # noqa: E501
405-
query_label_, candidates_labels_ = transform(query_labels, candidates_labels)
406-
truncated_candidates = candidates_labels_[:, :k]
407-
hit_mask = np.isin(query_label_[:, None], truncated_candidates).any(axis=1)
408-
return hit_mask.mean() # type: ignore[no-any-return]
409-
410-
363+
@ignore_oos
411364
def retrieval_precision(
412365
query_labels: LABELS_VALUE_TYPE,
413366
candidates_labels: CANDIDATE_TYPE,
@@ -449,6 +402,7 @@ def retrieval_precision(
449402
return float(total_precision / num_queries)
450403

451404

405+
@ignore_oos
452406
def retrieval_precision_intersecting(
453407
query_labels: LABELS_VALUE_TYPE,
454408
candidates_labels: CANDIDATE_TYPE,
@@ -496,6 +450,7 @@ def retrieval_precision_intersecting(
496450
return float(total_precision / num_queries)
497451

498452

453+
@ignore_oos
499454
def retrieval_precision_macro(
500455
query_labels: LABELS_VALUE_TYPE,
501456
candidates_labels: CANDIDATE_TYPE,
@@ -516,41 +471,6 @@ def retrieval_precision_macro(
516471
return _macrofy(retrieval_precision, query_labels, candidates_labels, k)
517472

518473

519-
def _retrieval_precision_numpy(
520-
query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_TYPE, k: int | None = None
521-
) -> float:
522-
r"""
523-
Calculate the precision at position k.
524-
525-
Precision at position :math:`k` is calculated as:
526-
527-
.. math::
528-
529-
\text{Precision@k} = \frac{1}{N} \sum_{i=1}^N \frac{\sum_{j=1}^k
530-
\mathbb{1}(y_{\text{query},i} = y_{\text{candidates},i,j})}{k}
531-
532-
where:
533-
- :math:`N` is the total number of queries,
534-
- :math:`y_{\text{query},i}` is the true label for the :math:`i`-th query,
535-
- :math:`y_{\text{candidates},i,j}` is the :math:`j`-th predicted label for the :math:`i`-th query,
536-
- :math:`\mathbb{1}(\text{condition})` is the indicator function that equals 1 if the
537-
condition is true and 0 otherwise,
538-
- :math:`k` is the number of top candidates considered.
539-
540-
:param query_labels: For each query, this list contains its class labels
541-
:param candidates_labels: For each query, these lists contain class labels of items ranked by a retrieval model
542-
(from most to least relevant)
543-
:param k: Number of top items to consider for each query
544-
:return: Score of the retrieval metric
545-
"""
546-
query_label_, candidates_labels_ = transform(query_labels, candidates_labels)
547-
top_k_candidates = candidates_labels_[:, :k]
548-
matches = (top_k_candidates == query_label_[:, None]).astype(int)
549-
relevant_counts = np.sum(matches, axis=1)
550-
precision_at_k = relevant_counts / k
551-
return np.mean(precision_at_k) # type: ignore[no-any-return]
552-
553-
554474
def _dcg(relevance_scores: npt.NDArray[Any], k: int | None = None) -> float:
555475
r"""
556476
Calculate the Discounted Cumulative Gain (DCG) at position k.
@@ -597,6 +517,7 @@ def _idcg(relevance_scores: npt.NDArray[Any], k: int | None = None) -> float:
597517
return _dcg(ideal_scores, k)
598518

599519

520+
@ignore_oos
600521
def retrieval_ndcg(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_TYPE, k: int | None = None) -> float:
601522
r"""
602523
Calculate the Normalized Discounted Cumulative Gain (NDCG) at position k.
@@ -632,6 +553,7 @@ def retrieval_ndcg(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE
632553
return float(np.mean(ndcg_scores))
633554

634555

556+
@ignore_oos
635557
def retrieval_ndcg_intersecting(
636558
query_labels: LABELS_VALUE_TYPE,
637559
candidates_labels: CANDIDATE_TYPE,
@@ -674,6 +596,7 @@ def retrieval_ndcg_intersecting(
674596
return np.mean(ndcg_scores) # type: ignore[return-value]
675597

676598

599+
@ignore_oos
677600
def retrieval_ndcg_macro(
678601
query_labels: LABELS_VALUE_TYPE,
679602
candidates_labels: CANDIDATE_TYPE,
@@ -692,6 +615,7 @@ def retrieval_ndcg_macro(
692615
return _macrofy(retrieval_ndcg, query_labels, candidates_labels, k)
693616

694617

618+
@ignore_oos
695619
def retrieval_mrr(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_TYPE, k: int | None = None) -> float:
696620
r"""
697621
Calculate the Mean Reciprocal Rank (MRR) at position k.
@@ -726,6 +650,7 @@ def retrieval_mrr(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_
726650
return float(mrr_sum / num_queries)
727651

728652

653+
@ignore_oos
729654
def retrieval_mrr_intersecting(
730655
query_labels: LABELS_VALUE_TYPE,
731656
candidates_labels: CANDIDATE_TYPE,
@@ -766,6 +691,7 @@ def retrieval_mrr_intersecting(
766691
return float(mrr_sum / num_queries)
767692

768693

694+
@ignore_oos
769695
def retrieval_mrr_macro(
770696
query_labels: LABELS_VALUE_TYPE,
771697
candidates_labels: CANDIDATE_TYPE,

autointent/metrics/scoring.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Scoring metrics for multiclass and multilabel classification tasks."""
22

33
import logging
4-
from typing import Protocol
4+
from functools import wraps
5+
from typing import Any, Protocol
56

67
import numpy as np
78
from sklearn.metrics import coverage_error, label_ranking_average_precision_score, label_ranking_loss, roc_auc_score
@@ -29,6 +30,23 @@ def __call__(self, labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> floa
2930
...
3031

3132

33+
34+
35+
def ignore_oos(func: ScoringMetricFn) -> ScoringMetricFn:
    """Ignore out-of-scope (OOS) samples in metrics calculation (decorator).

    OOS samples are marked with a ``None`` label. The wrapped metric is called
    only on in-scope samples: ``None``-labeled entries and their corresponding
    score rows are filtered out first, so OOS samples do not distort the score.

    :param func: Scoring metric function to wrap
    :return: Metric function that skips OOS samples before scoring
    """

    @wraps(func)
    def wrapper(labels: list[Any | None], scores: list[Any], *args: Any, **kwargs: Any) -> float:
        # Keep only in-scope samples; a None label marks an OOS sample.
        labels_filtered = [lab for lab in labels if lab is not None]
        # Drop the score rows that belong to the removed OOS samples,
        # keeping both sequences aligned (strict=True guards against length drift).
        scores_filtered = [score for score, lab in zip(scores, labels, strict=True) if lab is not None]
        # Forward any extra arguments untouched (e.g. ``eps`` of
        # ``scoring_log_likelihood``). The previous wrapper swallowed them,
        # so passing ``eps`` to the decorated function raised TypeError.
        return func(labels_filtered, scores_filtered, *args, **kwargs)

    return wrapper
47+
48+
49+
@ignore_oos
3250
def scoring_log_likelihood(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE, eps: float = 1e-10) -> float:
3351
r"""
3452
Supports multiclass and multilabel cases.
@@ -75,6 +93,7 @@ def scoring_log_likelihood(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE,
7593
return round(float(res), 6)
7694

7795

96+
@ignore_oos
7897
def scoring_roc_auc(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
7998
r"""
8099
Supports multiclass and multilabel cases.
@@ -126,6 +145,7 @@ def _calculate_decision_metric(
126145
return res
127146

128147

148+
@ignore_oos
129149
def scoring_accuracy(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
130150
r"""
131151
Calculate accuracy for multiclass and multilabel classification.
@@ -140,6 +160,7 @@ def scoring_accuracy(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> fl
140160
return _calculate_decision_metric(decision_accuracy, labels, scores)
141161

142162

163+
@ignore_oos
143164
def scoring_f1(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
144165
r"""
145166
Calculate the F1 score for multiclass and multilabel classification.
@@ -154,6 +175,7 @@ def scoring_f1(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
154175
return _calculate_decision_metric(decision_f1, labels, scores)
155176

156177

178+
@ignore_oos
157179
def scoring_precision(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
158180
r"""
159181
Calculate precision for multiclass and multilabel classification.
@@ -168,6 +190,7 @@ def scoring_precision(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> f
168190
return _calculate_decision_metric(decision_precision, labels, scores)
169191

170192

193+
@ignore_oos
171194
def scoring_recall(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
172195
r"""
173196
Calculate recall for multiclass and multilabel classification.
@@ -182,6 +205,7 @@ def scoring_recall(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> floa
182205
return _calculate_decision_metric(decision_recall, labels, scores)
183206

184207

208+
@ignore_oos
185209
def scoring_hit_rate(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
186210
r"""
187211
Calculate the hit rate for multilabel classification.
@@ -210,6 +234,7 @@ def scoring_hit_rate(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> fl
210234
return float(np.mean(is_in))
211235

212236

237+
@ignore_oos
213238
def scoring_neg_coverage(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
214239
"""
215240
Supports multilabel classification.
@@ -246,6 +271,7 @@ def scoring_neg_coverage(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -
246271
return float(1 - (coverage_error(labels, scores) - 1) / (n_classes - 1))
247272

248273

274+
@ignore_oos
249275
def scoring_neg_ranking_loss(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
250276
"""
251277
Supports multilabel.
@@ -262,6 +288,7 @@ def scoring_neg_ranking_loss(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYP
262288
return float(-label_ranking_loss(labels, scores))
263289

264290

291+
@ignore_oos
265292
def scoring_map(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
266293
r"""
267294
Calculate the mean average precision (MAP) score for multilabel classification.

0 commit comments

Comments
 (0)