Skip to content

Commit 732dfee

Browse files
committed
fix: added predict_proba
1 parent 0beb251 commit 732dfee

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

autointent/modules/embedding/_retrieval.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from autointent.metrics import (
1717
RETRIEVAL_METRICS_MULTICLASS,
1818
RETRIEVAL_METRICS_MULTILABEL,
19+
SCORING_METRICS_MULTICLASS,
20+
SCORING_METRICS_MULTILABEL,
1921
)
2022
from autointent.modules.abc import EmbeddingModule
2123

@@ -171,13 +173,9 @@ def score(
171173
raise ValueError(message)
172174

173175
embeddings = self.embedder.embed(utterances)
174-
predicted_encoded = self.classifier.predict(embeddings)
175-
predicted_labels = self.label_encoder.inverse_transform(predicted_encoded)
176-
177-
correct_predictions = sum(1 for true, pred in zip(labels, predicted_labels, strict=False) if true == pred)
178-
accuracy = correct_predictions / len(labels)
179-
180-
return {"scoring_accuracy": accuracy}
176+
predicted_encoded = self.classifier.predict_proba(embeddings)
177+
metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS
178+
return self.score_metrics((labels, predicted_encoded), metrics_dict)
181179

182180
def get_assets(self) -> RetrieverArtifact:
183181
"""

0 commit comments

Comments
 (0)