
Commit 1422dc3

- feat: parse_output only captures predefined label names
- feat: add max_length_for_labeling to handle long documents
- feat: prompt change
- refactor: train_anyclassifier positional args reshuffle
- docs: update README.md

1 parent 312c611 · commit 1422dc3

File tree

6 files changed (+40, −19 lines)


README.md

Lines changed: 8 additions & 4 deletions

@@ -17,15 +17,18 @@ Together let's build more useful models.
 
 ## 🚀 Features
 - One line to build any classifier that you don't have data 🤯
-- Why one line? Because it can easily be used by other LLM as a function call, easily to be integrated with any **agentic flow**
+- Why one line? Not only is it easy for humans to use, it can also be called by another LLM as a function, making it easy to integrate into any **agentic flow**
 - Smoothness integration with transformers, setfit, fasttext and datasets
   - [setfit](https://github.com/huggingface/setfit): for limited data (e.g. 100) 🤗
   - [fastText](https://github.com/facebookresearch/fastText): for blazingly fast inference (1000 docs/s) without GPU ⚡️
   - [transformers](https://github.com/huggingface/transformers): for other usecase
 - Huggingface-like interface for fastText that supports push_to_hub, saving and loading (let's not forget this amazing model before transformers architecture).
 
 ## 🏁 QuickStart in Colab
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LB8PUTT9wM1Qb2cY-6Dx-RNiqmyCvRr1?usp=sharing)
+| Dataset                       | Colab Link |
+|-------------------------------|------------|
+| imdb sentiment classification | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LB8PUTT9wM1Qb2cY-6Dx-RNiqmyCvRr1?usp=sharing) |
+
 
 ## 🔧 Installation
 It is using llama.cpp as backend, and build wheel can take a lot of time (10min+), as such, we also provide an instruction to install with pre-built wheel.

@@ -82,11 +85,11 @@ unlabeled_dataset # a huggingface datasets.Dataset class can be from your local
 # Magic One Line!
 trainer = build_anyclassifier(
     "Classify a text's sentiment.",
-    hf_hub_download("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q8_0.gguf"),  # as you like
     [
         Label(name='1', desc='positive sentiment'),
         Label(name='0', desc='negative sentiment')
     ],
+    hf_hub_download("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q8_0.gguf"),  # as you like
     unlabeled_dataset,
     column_mapping={"text": "text"},
     model_type="setfit",  # can be set to fastText

@@ -137,13 +140,14 @@ label_dataset.push_to_hub('user_id/any_data')
 
 ```
 
-See examples:
+See more examples:
 
 | model_type | example | resulting model | dataset |
 |------------|------------------------------------------|----------------------------------------------------------------------------------|------------------------------------------------------------------------------|
 | setfit | [link](examples/train_setfit_model.py) | [link](https://huggingface.co/kenhktsui/anyclassifier_setfit_demo) | [link](https://huggingface.co/datasets/kenhktsui/anyclassifier_dataset_demo) |
 | fasttext | [link](examples/train_fasttext_model.py) | [link](https://huggingface.co/kenhktsui/fasttext_test) (probably needs more labels) | [link](https://huggingface.co/datasets/kenhktsui/anyclassifier_dataset_demo) |
 
+Test accuracy on imdb with SetFit: 90.42%
 
 ## 🗺️ Roadmap
 - High Quality Data:

anyclassifier/annotation/annotator.py

Lines changed: 15 additions & 8 deletions

@@ -1,29 +1,33 @@
 import sys
 from abc import abstractmethod, ABCMeta
-from typing import Union, Optional
+from typing import Union, Optional, List
 import re
 from collections import Counter
 from tqdm import tqdm
 import logging
 from llama_cpp import Llama
 from datasets import Dataset  # it is import to load llama_cpp first before datasets to prevent error like https://github.com/abetlen/llama-cpp-python/issues/806
 from huggingface_hub import hf_hub_download
-from anyclassifier.annotation.prompt import AnnotationPrompt
+from anyclassifier.annotation.prompt import AnnotationPrompt, Label
 
 
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
 
 class AnnotatorBase(metaclass=ABCMeta):
+    def __init__(self):
+        self.regex_pattern = None
+
+    def prepare_regex_pattern(self, labels: List[Label]):
+        labels_str = "|".join(l.name for l in labels)
+        self.regex_pattern = re.compile(rf'Label:\s*({labels_str})')
 
-    regex_pattern = re.compile(r'Label:\s*(.+)')
     @abstractmethod
     def annotate(self, text: str) -> str:
         pass
 
-    @classmethod
-    def parse_output(cls, text: str) -> Optional[str]:
-        match = cls.regex_pattern.search(text)
+    def parse_output(self, text: str) -> Optional[str]:
+        match = self.regex_pattern.search(text)
         if match:
             return match.group(1)
         return None

@@ -36,6 +40,8 @@ def __init__(self,
                                               "Meta-Llama-3.1-8B-Instruct-Q8_0.gguf"),
                  n_gpu_layers: int = -1,
                  n_ctx: int = 2048):
+        super().__init__()
+        self.prepare_regex_pattern(prompt.label_definition)
         self._prompt = prompt
         self._llm = Llama(model_path=model_path,
                           n_gpu_layers=n_gpu_layers,

@@ -59,7 +65,8 @@ def annotate(self, text: str) -> str:
     def annotate_dataset(self,
                          dataset: Union[Dataset],
                          text_col: str = "text",
-                         n_record: int = 1000,
+                         n_record: int = 200,
+                         max_length_for_labeling: int = 1500,
                          shuffle: bool = True
                          ) -> Dataset:
         # shuffle the data to randomise potential bias in data collection process

@@ -70,7 +77,7 @@ def annotate_dataset(self,
 
         label_list = []
         for d in tqdm(selected_dataset, desc="Annotating dataset"):
-            llm_output = self.annotate(d[text_col])
+            llm_output = self.annotate(d[text_col][:max_length_for_labeling])
             label = self.parse_output(llm_output)
             label_list.append(label)
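The key behavioral change above is that the extraction pattern is now built from the predefined label names instead of the old catch-all `Label:\s*(.+)`. A minimal standalone sketch (hypothetical label names, free functions rather than the `AnnotatorBase` class) of what that buys:

```python
import re
from typing import List, Optional

def build_pattern(label_names: List[str]) -> "re.Pattern":
    # Like the diff, names are joined verbatim into an alternation; re.escape
    # would be safer if a label name ever contained regex metacharacters.
    return re.compile(rf"Label:\s*({'|'.join(label_names)})")

def parse_output(pattern: "re.Pattern", text: str) -> Optional[str]:
    # Only a predefined label name is captured; anything else yields None.
    match = pattern.search(text)
    return match.group(1) if match else None

pattern = build_pattern(["1", "0"])
print(parse_output(pattern, "Label: 1 (the review is positive)"))  # 1
print(parse_output(pattern, "Label: positive"))                    # None
```

Under the old `Label:\s*(.+)` pattern, the second call would have returned the free-form string "positive" rather than rejecting it, which is exactly the failure mode the commit closes.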

anyclassifier/annotation/prompt.py

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ def get_prompt(self, text: str):
             [f"Example {i+1}.\nText: {fse.text}\nLabel: {fse.label}" for i, fse in enumerate(self.few_shot_examples)]
         )
         return f"""{self.task_description}
-Here are the label definitions:
+Here are the label names and description:
 {label_defn_str}
 
 Here is the text to be analyzed:
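For context, here is a hedged reconstruction of how that template is assembled. Only the changed wording ("Here are the label names and description:") and the surrounding template fragment come from the diff; the `Label` dataclass and the `label_defn_str` formatting are assumptions for illustration:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Label:
    name: str  # the name parse_output is allowed to capture
    desc: str  # human-readable description shown to the annotator LLM

def get_prompt(task_description: str, labels: List[Label], text: str) -> str:
    # Assumed formatting: one "name: description" line per label.
    label_defn_str = "\n".join(f"{l.name}: {l.desc}" for l in labels)
    return f"""{task_description}
Here are the label names and description:
{label_defn_str}

Here is the text to be analyzed:
{text}"""

prompt = get_prompt(
    "Classify a text's sentiment.",
    [Label("1", "positive sentiment"), Label("0", "negative sentiment")],
    "I loved this movie.",
)
```

Listing the exact label names in the prompt pairs naturally with the stricter regex in annotator.py: the LLM is told which names are valid, and the parser rejects everything else.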

anyclassifier/train_any.py

Lines changed: 14 additions & 4 deletions

@@ -16,8 +16,8 @@
 
 def train_anyclassifier(
     instruction: str,
-    annotator_model_path: str,
     labels: List[Label],
+    annotator_model_path: str,
     unlabeled_dataset: Dataset,
     column_mapping: Dict[str, str] = {"text": "text"},
     model_type: Literal["setfit", "fasttext", "transformers"] = "setfit",

@@ -26,6 +26,7 @@ def train_anyclassifier(
     num_epochs: Optional[int] = 5,
     batch_size: Optional[int] = 16,
     n_record_to_label: int = 100,
+    max_length_for_labeling: int = 1500,
     test_size: float = 0.3,
     metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
     metric_kwargs: Optional[Dict[str, Any]] = None,

@@ -38,10 +39,10 @@ def train_anyclassifier(
     Args:
         instruction (`str`):
             The instruction to LLM annotator
-        annotator_model_path (`str`):
-            The LLM annotator model to be used by llama.cpp
         labels (`List[Label]`):
             The labels including name and desc you want to classify
+        annotator_model_path (`str`):
+            The path of the LLM annotator model to be used by llama.cpp
         unlabeled_dataset ('Dataset'):
             The unlabeled dataset you want to label.
         column_mapping (`Dict[str, str]`, *optional*):

@@ -60,6 +61,11 @@ def train_anyclassifier(
             Batch size to train model
         n_record_to_label (`int`, *optional*):
             No of record for LLM to label
+        max_length_for_labeling (`int`, *optional*):
+            Max length in characters, to avoid exceeding the LLM context length and to speed up annotation. How much
+            truncation affects annotation accuracy depends on factors such as the complexity of the classification
+            and the location of key information. If the same topic is conveyed throughout a document (e.g. sentiment
+            analysis, domain classification), the impact is expected to be low.
         test_size (`float`, *optional*):
             Proportion of labeled data to evaluation
         metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):

@@ -87,7 +93,11 @@
         few_shot_examples=few_shot_examples
     )
     annotator = LlamaCppAnnotator(prompt, annotator_model_path)
-    label_dataset = annotator.annotate_dataset(unlabeled_dataset, n_record=n_record_to_label)
+    label_dataset = annotator.annotate_dataset(
+        unlabeled_dataset,
+        n_record=n_record_to_label,
+        max_length_for_labeling=max_length_for_labeling
+    )
 
     label_dataset = label_dataset.train_test_split(test_size=test_size)
examples/train_fasttext_model.py

Lines changed: 1 addition & 1 deletion

@@ -14,11 +14,11 @@
 
 trainer = train_anyclassifier(
     "Classify a text's sentiment.",
-    hf_hub_download("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q8_0.gguf"),
     [
         Label(name='1', desc='positive sentiment'),
         Label(name='0', desc='negative sentiment')
     ],
+    hf_hub_download("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q8_0.gguf"),
     unlabeled_dataset,
     column_mapping={"text": "text"},
     model_type="fasttext",

examples/train_setfit_model.py

Lines changed: 1 addition & 1 deletion

@@ -14,11 +14,11 @@
 
 trainer = train_anyclassifier(
     "Classify a text's sentiment.",
-    hf_hub_download("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q8_0.gguf"),
     [
         Label(name='1', desc='positive sentiment'),
         Label(name='0', desc='negative sentiment')
     ],
+    hf_hub_download("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q8_0.gguf"),
     unlabeled_dataset,
     column_mapping={"text": "text"},
     model_type="setfit",

0 commit comments