
Commit 73a87eb

SeBorgey and voorhs authored
balancer jupyter notebook (#160)
* balancer jupyter notebook
* convert to rst
* update pyproject
* add references to tutorials

Co-authored-by: voorhs <[email protected]>
1 parent ed981ce commit 73a87eb

File tree: 7 files changed, +217 -11 lines

autointent/generation/utterances/balancer.py

Lines changed: 2 additions & 0 deletions

@@ -20,6 +20,8 @@ class DatasetBalancer:
     If your dataset is unbalanced, you can add LLM-generated samples.
     This method uses :py:class:`autointent.generation.utterances.UtteranceGenerator` under the hood.

+    See tutorial :ref:`balancer_aug` for usage examples.
+
     Args:
         generator (Generator): The generator object used to create utterances.
         prompt_maker (Callable[[Intent, int], list[Message]]): A callable that creates prompts for the generator.
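The `prompt_maker` signature above is the whole contract for customizing prompts: a callable that takes an Intent and a sample count and returns a list of chat messages. A minimal, hypothetical sketch of such a callable follows; the dict-based message shape and the `name`/`description` fields on Intent are assumptions for illustration, not the library's exact types:

    # Hypothetical prompt_maker matching Callable[[Intent, int], list[Message]].
    # The message shape and Intent fields below are assumed for illustration.
    def my_prompt_maker(intent, n_examples):
        return [
            {"role": "system", "content": "You write short user utterances for a chatbot."},
            {
                "role": "user",
                "content": f"Write {n_examples} utterances a user might say for the intent "
                           f"'{intent.name}' ({intent.description}).",
            },
        ]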
autointent/generation/utterances/evolution/__init__.py

Lines changed: 2 additions & 4 deletions

@@ -1,7 +1,5 @@
+from .dspy_evolver import DSPYIncrementalUtteranceEvolver
 from .evolver import UtteranceEvolver
 from .incremental_evolver import IncrementalUtteranceEvolver

-__all__ = [
-    "IncrementalUtteranceEvolver",
-    "UtteranceEvolver",
-]
+__all__ = ["DSPYIncrementalUtteranceEvolver", "IncrementalUtteranceEvolver", "UtteranceEvolver"]

autointent/generation/utterances/evolution/dspy_evolver.py

Lines changed: 2 additions & 0 deletions

@@ -152,6 +152,8 @@ class DSPYIncrementalUtteranceEvolver:

     For scoring generations, it uses a modified SemanticF1 as the base metric with ROUGE-1 as a repetition penalty.

+    See tutorial :ref:`evolutionary_strategy_augmentation` for usage examples.
+
     Args:
         model: Model name. This should follow the naming schema from `litellm providers <https://docs.litellm.ai/docs/providers>`_.
         api_base: API base URL. Some models require this.
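As a rough illustration of the scoring idea described in that docstring (not the module's actual implementation), a semantic-quality score can be damped by a ROUGE-1-style overlap with earlier generations:

    # Illustrative only: a SemanticF1-style score damped by a repetition penalty.
    # The real metric in DSPYIncrementalUtteranceEvolver may differ in detail.
    def rouge1_recall(candidate: str, reference: str) -> float:
        """Fraction of the reference's unique words that reappear in the candidate."""
        cand = set(candidate.lower().split())
        ref = set(reference.lower().split())
        return len(cand & ref) / len(ref) if ref else 0.0

    def combined_score(semantic_f1: float, candidate: str, previous: list[str]) -> float:
        """Final score = SemanticF1 * repetition factor; near-duplicates score low."""
        worst = max((rouge1_recall(candidate, p) for p in previous), default=0.0)
        return semantic_f1 * (1.0 - worst)

    # A near-duplicate of an earlier generation is penalized heavily:
    print(combined_score(0.9, "show me the weather", ["show me the weather today"]))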
docs/source/augmentation_tutorials/balancer.rst

Lines changed: 210 additions & 0 deletions

@@ -0,0 +1,210 @@
.. _balancer_aug:

Balancing Datasets with DatasetBalancer
=======================================

This guide demonstrates how to use the DatasetBalancer class to balance class distribution in your datasets through LLM-based data augmentation.

.. contents:: Table of Contents
   :depth: 2
Why Balance Datasets?
---------------------

Imbalanced datasets can lead to biased models that perform well on majority classes but poorly on minority classes. DatasetBalancer helps address this issue by generating additional examples for underrepresented classes using large language models.
Creating a Sample Imbalanced Dataset
------------------------------------

Let's create a small imbalanced dataset to demonstrate the balancing process:

.. code-block:: python

    from autointent import Dataset
    from autointent.generation.utterances.balancer import DatasetBalancer
    from autointent.generation.utterances.generator import Generator
    from autointent.generation.chat_templates import EnglishSynthesizerTemplate

    # Create a simple imbalanced dataset
    sample_data = {
        "intents": [
            {"id": 0, "name": "restaurant_booking", "description": "Booking a table at a restaurant"},
            {"id": 1, "name": "weather_query", "description": "Checking weather conditions"},
            {"id": 2, "name": "navigation", "description": "Getting directions to a location"},
        ],
        "train": [
            # Restaurant booking examples (5)
            {"utterance": "Book a table for two tonight", "label": 0},
            {"utterance": "I need a reservation at Le Bistro", "label": 0},
            {"utterance": "Can you reserve a table for me?", "label": 0},
            {"utterance": "I want to book a restaurant for my anniversary", "label": 0},
            {"utterance": "Make a dinner reservation for 8pm", "label": 0},

            # Weather query examples (3)
            {"utterance": "What's the weather like today?", "label": 1},
            {"utterance": "Will it rain tomorrow?", "label": 1},
            {"utterance": "Weather forecast for New York", "label": 1},

            # Navigation example (1)
            {"utterance": "How do I get to the museum?", "label": 2},
        ],
    }

    # Create the dataset
    dataset = Dataset.from_dict(sample_data)
Setting up the Generator and Template
-------------------------------------

DatasetBalancer requires two main components:

1. A **Generator**: responsible for creating new utterances using an LLM.
2. A **Template**: defines the prompt format sent to the LLM.

Let's set up these components:

.. code-block:: python

    # Initialize a generator (uses OpenAI API by default)
    generator = Generator()

    # Create a template for generating utterances
    template = EnglishSynthesizerTemplate(dataset=dataset, split="train")
Creating the DatasetBalancer
----------------------------

Now we can create our DatasetBalancer instance:

.. code-block:: python

    balancer = DatasetBalancer(
        generator=generator,
        prompt_maker=template,
        async_mode=False,  # Set to True for faster generation with async processing
        max_samples_per_class=5,  # Each class will have exactly 5 samples after balancing
    )
Checking Initial Class Distribution
-----------------------------------

Let's examine the class distribution before balancing:

.. code-block:: python

    # Check the initial distribution of classes in the training set
    initial_distribution = {}
    for sample in dataset["train"]:
        label = sample[Dataset.label_feature]
        initial_distribution[label] = initial_distribution.get(label, 0) + 1

    print("Initial class distribution:")
    for class_id, count in sorted(initial_distribution.items()):
        intent = next(i for i in dataset.intents if i.id == class_id)
        print(f"Class {class_id} ({intent.name}): {count} samples")

    print(f"\nMost represented class: {max(initial_distribution.values())} samples")
    print(f"Least represented class: {min(initial_distribution.values())} samples")
Balancing the Dataset
---------------------

Now we'll use the DatasetBalancer to augment our dataset:

.. code-block:: python

    # Create a copy of the dataset
    dataset_copy = Dataset.from_dict(dataset.to_dict())

    # Balance the training split
    balanced_dataset = balancer.balance(
        dataset=dataset_copy,
        split="train",
        batch_size=2,  # Process generations in batches of 2
    )
Checking the Results
--------------------

Let's examine the class distribution after balancing:

.. code-block:: python

    # Check the balanced distribution
    balanced_distribution = {}
    for sample in balanced_dataset["train"]:
        label = sample[Dataset.label_feature]
        balanced_distribution[label] = balanced_distribution.get(label, 0) + 1

    print("Balanced class distribution:")
    for class_id, count in sorted(balanced_distribution.items()):
        intent = next(i for i in dataset.intents if i.id == class_id)
        print(f"Class {class_id} ({intent.name}): {count} samples")

    print(f"\nMost represented class: {max(balanced_distribution.values())} samples")
    print(f"Least represented class: {min(balanced_distribution.values())} samples")
Examining Generated Examples
----------------------------

Let's look at some examples of original and generated utterances for the navigation class, which was the most underrepresented:

.. code-block:: python

    # Navigation class (Class 2)
    navigation_class_id = 2
    intent = next(i for i in dataset.intents if i.id == navigation_class_id)

    print(f"Examples for class {navigation_class_id} ({intent.name}):")

    # Original examples
    original_examples = [
        s[Dataset.utterance_feature] for s in dataset["train"] if s[Dataset.label_feature] == navigation_class_id
    ]
    print("\nOriginal examples:")
    for i, example in enumerate(original_examples, 1):
        print(f"{i}. {example}")

    # Generated examples
    all_examples = [
        s[Dataset.utterance_feature]
        for s in balanced_dataset["train"]
        if s[Dataset.label_feature] == navigation_class_id
    ]
    generated_examples = [ex for ex in all_examples if ex not in original_examples]
    print("\nGenerated examples:")
    for i, example in enumerate(generated_examples, 1):
        print(f"{i}. {example}")
Configuring the Number of Samples per Class
-------------------------------------------

You can configure how many samples each class should have:

.. code-block:: python

    # To bring all classes to exactly 10 samples
    original_dataset = Dataset.from_dict(sample_data)
    exact_template = EnglishSynthesizerTemplate(dataset=original_dataset, split="train")

    exact_balancer = DatasetBalancer(
        generator=generator,
        prompt_maker=exact_template,
        max_samples_per_class=10,
    )

    # Balance to the level of the most represented class
    max_template = EnglishSynthesizerTemplate(dataset=original_dataset, split="train")

    max_balancer = DatasetBalancer(
        generator=generator,
        prompt_maker=max_template,
        max_samples_per_class=None,  # Will use the count of the most represented class
    )
Tips for Effective Dataset Balancing
------------------------------------

1. **Quality Control**: Always review a sample of generated utterances to ensure quality (see the sketch after this list).
2. **Template Selection**: Different templates may work better for different domains.
3. **Model Selection**: Larger models generally produce higher-quality utterances.
4. **Batch Size**: Increase batch size for faster generation if your hardware allows.
5. **Validation**: Test your model on both original and augmented data to ensure it generalizes well.
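For tip 1, a minimal spot-check sketch; it only assumes the ``generated_examples`` list built in the previous section:

.. code-block:: python

    import random

    # Manually inspect a few generated utterances before training on them.
    review_sample = random.sample(generated_examples, k=min(3, len(generated_examples)))
    for utterance in review_sample:
        print(f"Please review: {utterance!r}")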

docs/source/augmentation_tutorials/dspy_augmentation.rst

Lines changed: 0 additions & 5 deletions

@@ -8,13 +8,11 @@ This tutorial covers the implementation and usage of an evolutionary strategy to
 .. contents:: Table of Contents
    :depth: 2

--------------
 What is DSPy?
 -------------

 DSPy is a framework for optimizing and evaluating language models. It provides tools for defining signatures, optimizing modules, and measuring evaluation metrics. This module leverages DSPy to generate augmented utterances using an evolutionary approach.

----------------------
 How This Module Works
 ---------------------

@@ -26,7 +24,6 @@ This module applies an incremental evolutionary strategy for augmenting utteranc

 The augmentation process runs for a specified number of evolutions, saving intermediate models and optimizing the results.

-------------
 Installation
 ------------

@@ -36,7 +33,6 @@ Ensure you have the required dependencies installed:

    pip install "autointent[dspy]"

---------------
 Scoring Metric
 --------------

@@ -54,7 +50,6 @@ The scoring metric consists of:

 - `Final Score = SemanticF1 * Repetition Factor`
 - A higher score means better augmentation.

--------------
 Usage Example
 -------------

docs/source/user_guides.rst

Lines changed: 1 addition & 0 deletions

@@ -16,4 +16,5 @@ Data augmentation tutorials
    :maxdepth: 1

    augmentation_tutorials/dspy_augmentation
+   augmentation_tutorials/balancer

pyproject.toml

Lines changed: 0 additions & 2 deletions

@@ -36,9 +36,7 @@ dependencies = [
     "scikit-learn (>=1.5,<2.0)",
     "scikit-multilearn (==0.2.0)",
     "appdirs (>=1.4,<2.0)",
-    "sre-yield (>=1.2,<2.0)",
     "optuna (>=4.0.0,<5.0.0)",
-    "xeger (>=0.4.0,<0.5.0)",
     "pathlib (>=1.0.1,<2.0.0)",
     "pydantic (>=2.10.5,<3.0.0)",
     "faiss-cpu (>=1.9.0,<2.0.0)",
