Skip to content

Commit 0beb251

Browse files
committed
fix: add accuracy metric for scorer in logreg
1 parent 599c262 commit 0beb251

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

autointent/modules/embedding/_retrieval.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,6 @@
1616
from autointent.metrics import (
1717
RETRIEVAL_METRICS_MULTICLASS,
1818
RETRIEVAL_METRICS_MULTILABEL,
19-
SCORING_METRICS_MULTICLASS,
20-
SCORING_METRICS_MULTILABEL,
2119
)
2220
from autointent.modules.abc import EmbeddingModule
2321

@@ -154,13 +152,13 @@ def score(
154152
self,
155153
context: Context,
156154
split: Literal["validation", "test"],
157-
) -> float:
155+
) -> dict[str, float | str]:
158156
"""
159-
Evaluate the model using a specified metric function.
157+
Evaluate the model using accuracy metric.
160158
161159
:param context: The context containing test data and labels.
162160
:param split: Target split ("validation" or "test").
163-
:return: Computed metric score.
161+
:return: Computed accuracy score.
164162
"""
165163
if split == "validation":
166164
utterances = context.data_handler.validation_utterances(0)
@@ -176,8 +174,10 @@ def score(
176174
predicted_encoded = self.classifier.predict(embeddings)
177175
predicted_labels = self.label_encoder.inverse_transform(predicted_encoded)
178176

179-
metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS
180-
return self.score_metrics(([labels], [predicted_labels]), metrics_dict)
177+
correct_predictions = sum(1 for true, pred in zip(labels, predicted_labels, strict=False) if true == pred)
178+
accuracy = correct_predictions / len(labels)
179+
180+
return {"scoring_accuracy": accuracy}
181181

182182
def get_assets(self) -> RetrieverArtifact:
183183
"""

0 commit comments

Comments (0)