1616from autointent .metrics import (
1717 RETRIEVAL_METRICS_MULTICLASS ,
1818 RETRIEVAL_METRICS_MULTILABEL ,
19- SCORING_METRICS_MULTICLASS ,
20- SCORING_METRICS_MULTILABEL ,
2119)
2220from autointent .modules .abc import EmbeddingModule
2321
@@ -154,13 +152,13 @@ def score(
154152 self ,
155153 context : Context ,
156154 split : Literal ["validation" , "test" ],
157- ) -> float :
155+ ) -> dict [ str , float | str ] :
158156 """
159- Evaluate the model using a specified metric function .
157+ Evaluate the model using accuracy metric.
160158
161159 :param context: The context containing test data and labels.
162160 :param split: Target split ("validation" or "test").
163- :return: Computed metric score.
161+ :return: Computed accuracy score.
164162 """
165163 if split == "validation" :
166164 utterances = context .data_handler .validation_utterances (0 )
@@ -176,8 +174,10 @@ def score(
176174 predicted_encoded = self .classifier .predict (embeddings )
177175 predicted_labels = self .label_encoder .inverse_transform (predicted_encoded )
178176
179- metrics_dict = SCORING_METRICS_MULTILABEL if context .is_multilabel () else SCORING_METRICS_MULTICLASS
180- return self .score_metrics (([labels ], [predicted_labels ]), metrics_dict )
177+ correct_predictions = sum (1 for true , pred in zip (labels , predicted_labels , strict = False ) if true == pred )
178+ accuracy = correct_predictions / len (labels )
179+
180+ return {"scoring_accuracy" : accuracy }
181181
182182 def get_assets (self ) -> RetrieverArtifact :
183183 """
0 commit comments