
Commit e1e6ca1

- feat: standardise metric and metric_kwargs
- feat: logging
1 parent fd8e9b1 commit e1e6ca1

File tree

3 files changed: +53 -10 lines changed


anyclassifier/annotation/annotator.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -1,13 +1,19 @@
+import sys
 from abc import abstractmethod, ABCMeta
 from typing import Union, Optional
 import re
+from collections import Counter
 from tqdm import tqdm
+import logging
 from llama_cpp import Llama
 from datasets import Dataset  # it is import to load llama_cpp first before datasets to prevent error like https://github.com/abetlen/llama-cpp-python/issues/806
 from huggingface_hub import hf_hub_download
 from anyclassifier.annotation.prompt import AnnotationPrompt
 
 
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+
+
 class AnnotatorBase(metaclass=ABCMeta):
 
     regex_pattern = re.compile(r'Label:\s*(.+)')
@@ -70,6 +76,9 @@ def annotate_dataset(self,
 
         selected_dataset = selected_dataset.add_column("label", label_list)
         selected_dataset = selected_dataset.filter(lambda x: x.get("label") is not None)
+        logging.info(f"""Count of labels
+{Counter(selected_dataset["label"]).most_common(len(self._prompt.label_definition))}
+""")
         return selected_dataset
 
 
```
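The new `logging.info` call reports the class balance of the annotated dataset before it is returned. A minimal sketch of what `Counter.most_common` contributes to that log line (the label values below are made up for illustration, not taken from the repository):

```python
from collections import Counter

# Hypothetical labels, standing in for selected_dataset["label"].
labels = ["positive", "positive", "negative", "positive", "neutral"]

# most_common(n) returns up to n (label, count) pairs, most frequent first,
# which is what gets interpolated into the "Count of labels" message.
print(Counter(labels).most_common(3))
# [('positive', 3), ('negative', 1), ('neutral', 1)]
```

This makes heavily skewed LLM annotations visible early, before a classifier is trained on them.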
anyclassifier/fasttext_wrapper/trainer.py

Lines changed: 23 additions & 6 deletions
```diff
@@ -2,11 +2,11 @@
 from dataclasses import dataclass, asdict
 import warnings
 from pathlib import Path
-from typing import Dict, Optional
+from typing import Dict, Optional, Union, Callable, Any
 import fasttext
 from datasets import Dataset
 from sklearn.preprocessing import LabelEncoder
-from sklearn.metrics import precision_score, recall_score
+import evaluate
 from setfit.trainer import ColumnMappingMixin
 from anyclassifier.fasttext_wrapper.config import FastTextConfig
 from anyclassifier.fasttext_wrapper.model import FastTextForSequenceClassification
@@ -54,6 +54,13 @@ class FastTextTrainer(ColumnMappingMixin):
             The training dataset.
         eval_dataset (`Dataset`, *optional*):
             The evaluation dataset.
+        metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):
+            The metric to use for evaluation. If a string is provided, we treat it as the metric
+            name and load it with default settings. If a callable is provided, it must take two arguments
+            (`y_pred`, `y_test`) and return a dictionary with metric keys to values.
+        metric_kwargs (`Dict[str, Any]`, *optional*):
+            Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1".
+            For example useful for providing an averaging strategy for computing f1 in a multi-label setting.
         column_mapping (`Dict[str, str]`, *optional*):
             A mapping from the column names in the dataset to the column names expected by the model.
             The expected format is a dictionary with the following format:
@@ -66,13 +73,17 @@ def __init__(
         args: FastTextConfig,
         train_dataset: Optional["Dataset"] = None,
         eval_dataset: Optional["Dataset"] = None,
+        metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
+        metric_kwargs: Optional[Dict[str, Any]] = None,
         column_mapping: Optional[Dict[str, str]] = None,
     ) -> None:
         if args is not None and not isinstance(args, FastTextTrainingArguments):
             raise ValueError("`args` must be a `FastTextTrainingArguments` instance.")
         self.training_args = asdict(args)
         self.output_dir = self.training_args.pop("output_dir")
         self.data_txt_path = self.training_args.pop("data_txt_path")
+        self.metric = metric
+        self.metric_kwargs = metric_kwargs
         self.column_mapping = column_mapping
         if train_dataset:
             self._validate_column_mapping(train_dataset)
@@ -165,10 +176,16 @@ def evaluate(self, dataset: Dataset) -> Dict[str, float]:
         le.fit(label + label_pred)
         label = le.transform(label)
         label_pred = le.transform(label_pred)
-        return {
-            "precision": precision_score(label, label_pred, average="micro"),
-            "recall": recall_score(label, label_pred, average="micro")
-        }
+
+        metric_kwargs = self.metric_kwargs or {}
+        if isinstance(self.metric, str):
+            metric_fn = evaluate.load(self.metric)
+            results = metric_fn.compute(predictions=y_pred, references=y_test, **metric_kwargs)
+        elif callable(self.metric):
+            results = self.metric(y_pred, y_test, **metric_kwargs)
+        else:
+            raise ValueError("metric must be a string or a callable")
+        return {"metric": results}
 
     def push_to_hub(self, repo_id: str, **kwargs) -> str:
         """Upload model checkpoint to the Hub using `huggingface_hub`.
```

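The rewritten `evaluate` now dispatches on the type of `metric`: a string is loaded through the Hugging Face `evaluate` library with `metric_kwargs` forwarded to `compute`, while a callable is invoked directly with the predictions and references. A standalone sketch of that dispatch, using toy encoded labels that are not from the repository:

```python
import evaluate

def compute_metric(metric, y_pred, y_test, metric_kwargs=None):
    """Illustrative stand-in for the dispatch added to FastTextTrainer.evaluate."""
    metric_kwargs = metric_kwargs or {}
    if isinstance(metric, str):
        metric_fn = evaluate.load(metric)  # e.g. "accuracy", "f1", "precision"
        return metric_fn.compute(predictions=y_pred, references=y_test, **metric_kwargs)
    if callable(metric):
        # A callable must accept (y_pred, y_test) and return a dict of metric values.
        return metric(y_pred, y_test, **metric_kwargs)
    raise ValueError("metric must be a string or a callable")

# String metric, with kwargs forwarded to compute (labels already label-encoded):
print(compute_metric("f1", y_pred=[0, 1, 1, 0], y_test=[0, 1, 0, 0],
                     metric_kwargs={"average": "macro"}))
```

Note that the trainer wraps whatever comes back under a single `"metric"` key, so callers read `results["metric"]` regardless of which branch produced it.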
anyclassifier/train_any.py

Lines changed: 21 additions & 4 deletions
```diff
@@ -1,4 +1,6 @@
-from typing import List, Dict, Union, Literal, Optional
+import sys
+from typing import List, Dict, Union, Literal, Optional, Callable, Any
+import logging
 from datasets import Dataset
 from huggingface_hub import interpreter_login
 from setfit import SetFitModel, TrainingArguments, Trainer as SetFitTrainer
@@ -9,6 +11,9 @@
 )
 
 
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+
+
 def train_anyclassifier(
     instruction: str,
     annotator_model_path: str,
@@ -22,6 +27,8 @@ def train_anyclassifier(
     batch_size: Optional[int] = 16,
     n_record_to_label: int = 100,
     test_size: float = 0.3,
+    metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
+    metric_kwargs: Optional[Dict[str, Any]] = None,
     push_dataset_to_hub: bool = False,
     dataset_repo_id: Optional[str] = None,
     is_dataset_private: Optional[bool] = True,
@@ -55,6 +62,13 @@ def train_anyclassifier(
             No of record for LLM to label
         test_size (`float`, *optional*):
             Proportion of labeled data to evaluation
+        metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):
+            The metric to use for evaluation. If a string is provided, we treat it as the metric
+            name and load it with default settings. If a callable is provided, it must take two arguments
+            (`y_pred`, `y_test`) and return a dictionary with metric keys to values.
+        metric_kwargs (`Dict[str, Any]`, *optional*):
+            Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1".
+            For example useful for providing an averaging strategy for computing f1 in a multi-label setting.
         push_dataset_to_hub (`bool`, *optional*):
             Whether to push dataset to huggingface hub for reuse, highly recommended to do so.
         dataset_repo_id (`str`, *optional*):
@@ -92,13 +106,15 @@ def train_anyclassifier(
             args=args,
             train_dataset=label_dataset["train"],
             eval_dataset=label_dataset["test"],
+            metric=metric,
+            metric_kwargs=metric_kwargs,
             column_mapping={**column_mapping, "label": "label"},
         )
 
         # Train and evaluate
         trainer.train()
         metrics = trainer.evaluate(label_dataset["test"])
-        print(metrics)
+        logging.info(metrics)
         return trainer
 
     elif model_type == "setfit":
@@ -120,14 +136,15 @@ def train_anyclassifier(
             args=args,
             train_dataset=label_dataset["train"],
             eval_dataset=label_dataset["test"],
-            metric="accuracy",
+            metric=metric,
+            metric_kwargs=metric_kwargs,
             column_mapping={**column_mapping, "label": "label"},
         )
 
         # Train and evaluate
         trainer.train()
         metrics = trainer.evaluate(label_dataset["test"])
-        print(metrics)
+        logging.info(metrics)
         return trainer
     else:
         raise NotImplementedError("other approach is not implemented yet")
```
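With `metric` and `metric_kwargs` threaded through to both the fastText and SetFit branches, a caller can either name an `evaluate`-style metric or supply a callable that follows the `(y_pred, y_test) -> dict` contract described in the docstring. A small sketch of such a callable; the metric names and the commented call pattern are illustrative, and the other `train_anyclassifier` arguments are omitted:

```python
from sklearn.metrics import f1_score, precision_score, recall_score

def report_metric(y_pred, y_test):
    """Callable metric: takes (y_pred, y_test), returns a dict of metric values."""
    return {
        "f1_macro": f1_score(y_test, y_pred, average="macro"),
        "precision_micro": precision_score(y_test, y_pred, average="micro"),
        "recall_micro": recall_score(y_test, y_pred, average="micro"),
    }

# Passed straight through to the underlying trainer, e.g.:
#   train_anyclassifier(..., metric=report_metric)
# or, for a built-in metric with options:
#   train_anyclassifier(..., metric="f1", metric_kwargs={"average": "macro"})
print(report_metric([0, 1, 1, 0], [0, 1, 0, 0]))
```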
