22from dataclasses import dataclass , asdict
33import warnings
44from pathlib import Path
5- from typing import Dict , Optional
5+ from typing import Dict , Optional , Union , Callable , Any
66import fasttext
77from datasets import Dataset
88from sklearn .preprocessing import LabelEncoder
9- from sklearn . metrics import precision_score , recall_score
9+ import evaluate
1010from setfit .trainer import ColumnMappingMixin
1111from anyclassifier .fasttext_wrapper .config import FastTextConfig
1212from anyclassifier .fasttext_wrapper .model import FastTextForSequenceClassification
@@ -54,6 +54,13 @@ class FastTextTrainer(ColumnMappingMixin):
5454 The training dataset.
5555 eval_dataset (`Dataset`, *optional*):
5656 The evaluation dataset.
57+ metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):
58+ The metric to use for evaluation. If a string is provided, we treat it as the metric
59+ name and load it with default settings. If a callable is provided, it must take two arguments
60+ (`y_pred`, `y_test`) and return a dictionary mapping metric names to values.
61+ metric_kwargs (`Dict[str, Any]`, *optional*):
62+ Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1".
63+ For example, this is useful for providing an averaging strategy when computing f1 in a multi-label setting.
5764 column_mapping (`Dict[str, str]`, *optional*):
5865 A mapping from the column names in the dataset to the column names expected by the model.
5966 The expected format is a dictionary with the following format:
@@ -66,13 +73,17 @@ def __init__(
6673 args : FastTextConfig ,
6774 train_dataset : Optional ["Dataset" ] = None ,
6875 eval_dataset : Optional ["Dataset" ] = None ,
76+ metric : Union [str , Callable [["Dataset" , "Dataset" ], Dict [str , float ]]] = "accuracy" ,
77+ metric_kwargs : Optional [Dict [str , Any ]] = None ,
6978 column_mapping : Optional [Dict [str , str ]] = None ,
7079 ) -> None :
7180 if args is not None and not isinstance (args , FastTextTrainingArguments ):
7281 raise ValueError ("`args` must be a `FastTextTrainingArguments` instance." )
7382 self .training_args = asdict (args )
7483 self .output_dir = self .training_args .pop ("output_dir" )
7584 self .data_txt_path = self .training_args .pop ("data_txt_path" )
85+ self .metric = metric
86+ self .metric_kwargs = metric_kwargs
7687 self .column_mapping = column_mapping
7788 if train_dataset :
7889 self ._validate_column_mapping (train_dataset )
@@ -165,10 +176,16 @@ def evaluate(self, dataset: Dataset) -> Dict[str, float]:
165176 le .fit (label + label_pred )
166177 label = le .transform (label )
167178 label_pred = le .transform (label_pred )
168- return {
169- "precision" : precision_score (label , label_pred , average = "micro" ),
170- "recall" : recall_score (label , label_pred , average = "micro" )
171- }
179+
180+ metric_kwargs = self .metric_kwargs or {}
181+ if isinstance (self .metric , str ):
182+ metric_fn = evaluate .load (self .metric )
183+ results = metric_fn .compute (predictions = y_pred , references = y_test , ** metric_kwargs )
184+ elif callable (self .metric ):
185+ results = self .metric (y_pred , y_test , ** metric_kwargs )
186+ else :
187+ raise ValueError ("metric must be a string or a callable" )
188+ return {"metric" : results }
172189
173190 def push_to_hub (self , repo_id : str , ** kwargs ) -> str :
174191 """Upload model checkpoint to the Hub using `huggingface_hub`.
0 commit comments