 import numpy.typing as npt
 from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score

-from autointent.custom_types import LabelType
+from autointent.custom_types import ListOfGenericLabels, ListOfLabels

 from ._converter import transform
-from .custom_types import LABELS_VALUE_TYPE

 logger = logging.getLogger(__name__)


 class DecisionMetricFn(Protocol):
     """Protocol for decision metrics."""

-    def __call__(self, y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
+    def __call__(self, y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float:
         """
         Calculate decision metric.

@@ -32,17 +31,14 @@ def __call__(self, y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> floa
         ...


-def handle_oos(
-    y_true: list[LabelType | None], y_pred: list[LabelType | None]
-) -> tuple[list[LabelType], list[LabelType]]:
+def handle_oos(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> tuple[ListOfLabels, ListOfLabels]:
     """Convert labels of OOS samples to make them usable in decision metrics."""
     in_domain_labels = list(filter(lambda lab: lab is not None, y_true))
-    multilabel = isinstance(in_domain_labels[0], list)
-    if multilabel:
+    if isinstance(in_domain_labels[0], list):
         func = _add_oos_multilabel
         n_classes = len(in_domain_labels[0])
     else:
-        func = _add_oos_multiclass
+        func = _add_oos_multiclass  # type: ignore[assignment]
         n_classes = len(set(in_domain_labels))
     func = partial(func, n_classes=n_classes)
     return list(map(func, y_true)), list(map(func, y_pred))
@@ -60,7 +56,7 @@ def _add_oos_multilabel(label: list[int] | None, n_classes: int) -> list[int]:
     return [*label, 1]


-def decision_accuracy(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
+def decision_accuracy(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float:
     r"""
     Calculate decision accuracy. Supports both multiclass and multilabel.

@@ -131,7 +127,7 @@ def _decision_roc_auc_multilabel(y_true: npt.NDArray[Any], y_pred: npt.NDArray[A
     return float(roc_auc_score(y_true, y_pred, average="macro"))


-def decision_roc_auc(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
+def decision_roc_auc(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float:
     r"""
     Calculate ROC AUC for multiclass and multilabel classification.

@@ -153,7 +149,7 @@ def decision_roc_auc(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> fl
     raise ValueError(msg)


-def decision_precision(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
+def decision_precision(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float:
     r"""
     Calculate decision precision. Supports both multiclass and multilabel.

@@ -168,7 +164,7 @@ def decision_precision(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) ->
     return float(precision_score(*handle_oos(y_true, y_pred), average="macro"))


-def decision_recall(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
+def decision_recall(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float:
     r"""
     Calculate decision recall. Supports both multiclass and multilabel.

@@ -183,7 +179,7 @@ def decision_recall(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> flo
     return float(recall_score(*handle_oos(y_true, y_pred), average="macro"))


-def decision_f1(y_true: LABELS_VALUE_TYPE, y_pred: LABELS_VALUE_TYPE) -> float:
+def decision_f1(y_true: ListOfGenericLabels, y_pred: ListOfGenericLabels) -> float:
     r"""
     Calculate decision f1 score. Supports both multiclass and multilabel.
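
Usage sketch of the retyped API (not part of the diff): ListOfGenericLabels replaces list[LabelType | None] and ListOfLabels replaces list[LabelType], so out-of-scope samples are still encoded as None and handle_oos still folds them into an extra class before scoring. The import path below is an assumption, since the module path of this file is not shown in the diff.

# Hypothetical example; None marks an out-of-scope (OOS) sample.
from autointent.metrics import decision_f1  # assumed import path

y_true = [0, 2, 1, None]  # ListOfGenericLabels-style multiclass ground truth
y_pred = [0, 1, 1, None]  # predictions from the decision module

score = decision_f1(y_true, y_pred)  # float; OOS handled internally via handle_oos
print(score)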