@@ -1,6 +1,5 @@
 """Retrieval metrics."""

-from collections.abc import Callable
 from typing import Any, Protocol

 import numpy as np
@@ -36,7 +35,7 @@ def __call__(


 def _macrofy(
-    metric_fn: Callable[[npt.NDArray[Any], npt.NDArray[Any], int | None], float],
+    metric_fn: RetrievalMetricFn,
     query_labels: LABELS_VALUE_TYPE,
     candidates_labels: CANDIDATE_TYPE,
     k: int | None = None,
@@ -72,7 +71,7 @@ def _macrofy(
     for i in range(n_classes):
         binarized_query_labels = query_labels_[..., i]
         binarized_candidates_labels = candidates_labels_[..., i]
-        classwise_values.append(metric_fn(binarized_query_labels, binarized_candidates_labels, k))
+        classwise_values.append(metric_fn(binarized_query_labels, binarized_candidates_labels, k))  # type: ignore[arg-type]

     return np.mean(classwise_values)  # type: ignore[return-value]

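For context on the new annotation: RetrievalMetricFn, which replaces the inline Callable type above, is presumably a typing.Protocol defined earlier in this module (the "def __call__(" context in the first hunk header points that way). Below is a minimal sketch reconstructed from the removed Callable[[npt.NDArray[Any], npt.NDArray[Any], int | None], float] signature, not the actual definition; the parameter names are guesses.

from typing import Any, Protocol

import numpy.typing as npt


class RetrievalMetricFn(Protocol):
    """Sketch of a callable protocol for a per-query retrieval metric (assumed shape)."""

    def __call__(
        self,
        query_labels: npt.NDArray[Any],
        candidates_labels: npt.NDArray[Any],
        k: int | None = None,
    ) -> float: ...

Because Protocol matching is structural, any plain function with this shape satisfies the annotation without subclassing, which would let _macrofy keep accepting ordinary metric functions.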
@@ -136,12 +135,12 @@ def retrieval_map(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_
     :param k: Number of top items to consider for each query
     :return: Score of the retrieval metric
     """
-    ap_list = [_average_precision(q, c, k) for q, c in zip(query_labels, candidates_labels, strict=True)]
+    ap_list = [_average_precision(q, c, k) for q, c in zip(query_labels, candidates_labels, strict=True)]  # type: ignore[arg-type]
     return sum(ap_list) / len(ap_list)


 def _average_precision_intersecting(
-    query_label: LABELS_VALUE_TYPE, candidate_labels: CANDIDATE_TYPE, k: int | None = None
+    query_label: list[int], candidate_labels: CANDIDATE_TYPE, k: int | None = None
 ) -> float:
     r"""
     Calculate the average precision at position k for the intersecting labels.
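The _average_precision helper that retrieval_map averages over is not part of this diff. For background, a standard AP@k over binary relevance looks roughly like the sketch below; the relevance rule (a candidate counts as relevant when its label equals the query's label), the function name, and the argument types are assumptions for illustration, not the repository's actual implementation.

def _average_precision_sketch(query_label: int, candidate_labels: list[int], k: int | None = None) -> float:
    """Illustrative AP@k: average the precision at each rank where a relevant candidate appears."""
    top_k = candidate_labels if k is None else candidate_labels[:k]
    num_relevant = 0
    precision_sum = 0.0
    for rank, label in enumerate(top_k, start=1):
        if label == query_label:  # assumed relevance rule: exact label match
            num_relevant += 1
            precision_sum += num_relevant / rank  # precision at this relevant rank
    return precision_sum / num_relevant if num_relevant else 0.0


# Worked example: relevant items at ranks 1 and 3 give (1/1 + 2/3) / 2.
assert abs(_average_precision_sketch(1, [1, 0, 1], k=3) - (1 + 2 / 3) / 2) < 1e-9

retrieval_map then takes the mean of these per-query AP values, which is the usual mean average precision.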
@@ -212,7 +211,7 @@ def retrieval_map_intersecting(
     :param k: Number of top items to consider for each query
     :return: Score of the retrieval metric
     """
-    ap_list = [_average_precision_intersecting(q, c, k) for q, c in zip(query_labels, candidates_labels, strict=True)]
+    ap_list = [_average_precision_intersecting(q, c, k) for q, c in zip(query_labels, candidates_labels, strict=True)]  # type: ignore[arg-type]
     return sum(ap_list) / len(ap_list)


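Finally, a hypothetical usage sketch of the entry point touched by this last hunk. The LABELS_VALUE_TYPE and CANDIDATE_TYPE aliases are not shown in the diff, so the import path and the data layout below (one multilabel vector per query, and the label vectors of that query's retrieved candidates in rank order) are assumptions.

from retrieval_metrics import retrieval_map_intersecting  # hypothetical module path

query_labels = [[1, 0, 1], [0, 1, 0]]  # assumed: one multilabel vector per query
candidates_labels = [
    [[1, 0, 0], [0, 1, 0], [1, 0, 1]],  # assumed: candidates retrieved for query 0, in rank order
    [[0, 1, 1], [1, 0, 0], [0, 1, 0]],  # assumed: candidates retrieved for query 1, in rank order
]

score = retrieval_map_intersecting(query_labels, candidates_labels, k=2)
print(f"MAP@2 (intersecting): {score:.3f}")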