Skip to content

Commit 599c262

Browse files
committed
fix: fixed score func
1 parent 36d704c commit 599c262

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

autointent/modules/embedding/_retrieval.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
 from autointent import Context, Embedder, VectorIndex
 from autointent.context.optimization_info import RetrieverArtifact
 from autointent.custom_types import BaseMetadataDict, LabelType
-from autointent.metrics import RETRIEVAL_METRICS_MULTICLASS, RETRIEVAL_METRICS_MULTILABEL, ScoringMetricFn
+from autointent.metrics import (
+    RETRIEVAL_METRICS_MULTICLASS,
+    RETRIEVAL_METRICS_MULTILABEL,
+    SCORING_METRICS_MULTICLASS,
+    SCORING_METRICS_MULTILABEL,
+)
 from autointent.modules.abc import EmbeddingModule

1924

@@ -149,14 +154,12 @@ def score(
         self,
         context: Context,
         split: Literal["validation", "test"],
-        metric_fn: ScoringMetricFn,
     ) -> float:
         """
         Evaluate the model using a specified metric function.

         :param context: The context containing test data and labels.
         :param split: Target split ("validation" or "test").
-        :param metric_fn: Function to compute the retrieval metric.
         :return: Computed metric score.
         """
         if split == "validation":
@@ -173,7 +176,8 @@ def score(
         predicted_encoded = self.classifier.predict(embeddings)
         predicted_labels = self.label_encoder.inverse_transform(predicted_encoded)

-        return metric_fn(labels, predicted_labels.reshape(-1, 1))
+        metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS
+        return self.score_metrics(([labels], [predicted_labels]), metrics_dict)

     def get_assets(self) -> RetrieverArtifact:
         """

0 commit comments

Comments (0)