1313from autointent import Context , Embedder , VectorIndex
1414from autointent .context .optimization_info import RetrieverArtifact
1515from autointent .custom_types import BaseMetadataDict , LabelType
16- from autointent .metrics import RETRIEVAL_METRICS_MULTICLASS , RETRIEVAL_METRICS_MULTILABEL , ScoringMetricFn
16+ from autointent .metrics import (
17+ RETRIEVAL_METRICS_MULTICLASS ,
18+ RETRIEVAL_METRICS_MULTILABEL ,
19+ SCORING_METRICS_MULTICLASS ,
20+ SCORING_METRICS_MULTILABEL ,
21+ )
1722from autointent .modules .abc import EmbeddingModule
1823
1924
@@ -149,14 +154,12 @@ def score(
149154 self ,
150155 context : Context ,
151156 split : Literal ["validation" , "test" ],
152- metric_fn : ScoringMetricFn ,
153157 ) -> float :
154158 """
155159 Evaluate the model using a specified metric function.
156160
157161 :param context: The context containing test data and labels.
158162 :param split: Target split ("validation" or "test").
159- :param metric_fn: Function to compute the retrieval metric.
160163 :return: Computed metric score.
161164 """
162165 if split == "validation" :
@@ -173,7 +176,8 @@ def score(
173176 predicted_encoded = self .classifier .predict (embeddings )
174177 predicted_labels = self .label_encoder .inverse_transform (predicted_encoded )
175178
176- return metric_fn (labels , predicted_labels .reshape (- 1 , 1 ))
179+ metrics_dict = SCORING_METRICS_MULTILABEL if context .is_multilabel () else SCORING_METRICS_MULTICLASS
180+ return self .score_metrics (([labels ], [predicted_labels ]), metrics_dict )
177181
178182 def get_assets (self ) -> RetrieverArtifact :
179183 """
0 commit comments