Skip to content

Commit a5b777e

Browse files
authored
Augmentation for balancing the dataset (#148)
* skeleton of code balancer * first working tested balancer * review changes * fix autocheck * fixcheck 2: return of the fixcheck * fix after second review * fix after pull dev * check autofix: episode 3
1 parent 426af0d commit a5b777e

File tree

3 files changed

+257
-0
lines changed

3 files changed

+257
-0
lines changed

autointent/generation/utterances/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .balancer import DatasetBalancer
12
from .basic import EnglishSynthesizerTemplate, RussianSynthesizerTemplate, UtteranceGenerator
23
from .evolution import (
34
AbstractEvolution,
@@ -16,6 +17,7 @@
1617
__all__ = [
1718
"AbstractEvolution",
1819
"ConcreteEvolution",
20+
"DatasetBalancer",
1921
"EvolutionChatTemplate",
2022
"FormalEvolution",
2123
"FunnyEvolution",
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
"""Module for balancing datasets through augmentation of underrepresented classes."""
2+
3+
import logging
4+
from collections import defaultdict
5+
6+
from datasets import Dataset as HFDataset
7+
8+
from autointent import Dataset
9+
from autointent.custom_types import Split
10+
from autointent.generation.utterances.basic.chat_templates._base import BaseSynthesizerTemplate
11+
from autointent.generation.utterances.basic.utterance_generator import UtteranceGenerator
12+
from autointent.generation.utterances.generator import Generator
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
class DatasetBalancer:
    """Balance a single-label dataset by generating extra utterances for underrepresented classes."""

    def __init__(
        self,
        generator: Generator,
        prompt_maker: BaseSynthesizerTemplate,
        async_mode: bool = False,
        max_samples_per_class: int | None = None,
    ) -> None:
        """
        Initialize the DatasetBalancer.

        Args:
            generator: The generator object used to create utterances.
            prompt_maker: Synthesizer template that builds prompts for the generator.
            async_mode: Whether to run the generator in asynchronous mode. Defaults to False.
            max_samples_per_class: The maximum number of samples per class.
                Must be a positive integer or None (balance up to the largest class).
                Defaults to None.

        Raises:
            ValueError: If max_samples_per_class is not None and is less than or equal to 0.
        """
        if max_samples_per_class is not None and max_samples_per_class <= 0:
            msg = "max_samples_per_class must be a positive integer or None"
            raise ValueError(msg)

        self.utterance_generator = UtteranceGenerator(
            generator=generator, prompt_maker=prompt_maker, async_mode=async_mode
        )
        self.max_samples = max_samples_per_class

    def balance(self, dataset: Dataset, split: str = Split.TRAIN, batch_size: int = 4) -> Dataset:
        """
        Balance the specified dataset split in place.

        Args:
            dataset: Source dataset (mutated in place and also returned).
            split: Target split for balancing. Defaults to the train split.
            batch_size: Number of utterances requested from the generator per attempt.

        Returns:
            The same dataset with the target split balanced.

        Raises:
            ValueError: If the dataset is multilabel.
        """
        if dataset.multilabel:
            msg = "Method supports only single-label datasets"
            raise ValueError(msg)

        class_counts = self._count_class_examples(dataset, split)
        max_count = max(class_counts.values())
        # Balance either up to the configured cap or up to the largest observed class.
        target_count = self.max_samples if self.max_samples is not None else max_count
        logger.debug("Target count per class: %s", target_count)
        for class_id, current_count in class_counts.items():
            if current_count < target_count:
                needed = target_count - current_count
                self._augment_class(dataset, split, class_id, needed, batch_size)

        return dataset

    def _count_class_examples(self, dataset: Dataset, split: str) -> dict[int, int]:
        """Count the number of examples for each class in the given split."""
        counts: dict[int, int] = defaultdict(int)
        for sample in dataset[split]:
            counts[sample[Dataset.label_feature]] += 1
        return counts

    def _augment_class(self, dataset: Dataset, split: str, class_id: int, needed: int, batch_size: int) -> None:
        """Generate `needed` additional examples for one class and append them to the split.

        Raises:
            ValueError: If no intent definition or no existing samples exist for the class.
        """
        # FIX: `next(...)` without a default raised a bare StopIteration for an unknown
        # class id; raise a descriptive ValueError instead.
        intent = next((i for i in dataset.intents if i.id == class_id), None)
        if intent is None:
            msg = f"No samples for class {class_id}"
            raise ValueError(msg)
        class_name = getattr(intent, "name", f"class_{class_id}")

        # FIX: the class-sample list was built twice with identical comprehensions; build once.
        class_samples = [s for s in dataset[split] if s[Dataset.label_feature] == class_id]
        logger.debug("Starting augmentation for class %s (%s)", class_id, class_name)
        logger.debug("Initial samples: %s", len(class_samples))
        logger.debug("Target needed: %s samples", needed)

        if not class_samples:
            msg = f"No samples for class {class_id}"
            raise ValueError(msg)

        generated_utterances: list[str] = []
        max_attempts = 5  # guard against a generator that never yields valid output
        attempts = 0

        while len(generated_utterances) < needed and attempts < max_attempts:
            current_needed = needed - len(generated_utterances)
            current_batch = min(batch_size, current_needed)
            logger.debug("Attempt %s: Generating %s utterances for class %s", attempts + 1, current_batch, class_id)

            new_utterances = self.utterance_generator(intent_data=intent, n_generations=current_batch)

            valid_utterances = self._process_utterances(new_utterances)
            for ut in valid_utterances:
                if ut and isinstance(ut, str):
                    generated_utterances.append(ut)
                if len(generated_utterances) >= needed:
                    break

            logger.debug("Generated %s valid utterances in this attempt", len(valid_utterances))
            logger.debug(
                "Progress: %s/%s (%s%%)",
                len(generated_utterances),
                needed,
                min(100, int(len(generated_utterances) / needed * 100)),
            )

            attempts += 1

        if len(generated_utterances) < needed:
            # FIX: a generation shortfall was logged at DEBUG with a "Warning:" prefix;
            # use the proper WARNING level instead.
            logger.warning(
                "Could only generate %s/%s utterances after %s attempts",
                len(generated_utterances),
                needed,
                max_attempts,
            )

        # _process_utterances may split one raw string into several items; never add more than needed.
        generated_utterances = generated_utterances[:needed]

        new_samples = [
            {Dataset.utterance_feature: utterance, Dataset.label_feature: class_id}
            for utterance in generated_utterances
        ]

        updated_data = list(dataset[split]) + new_samples
        dataset[split] = HFDataset.from_list(updated_data)

        final_count = sum(1 for s in dataset[split] if s[Dataset.label_feature] == class_id)
        logger.debug("Completed augmentation for class %s (%s)", class_id, class_name)
        logger.debug("Total samples after augmentation: %s", final_count)

    def _process_utterances(self, generated: list[str]) -> list[str]:
        """Clean generated utterances, splitting stringified lists (e.g. "['a', 'b']") into items."""
        processed = []
        for ut in generated:
            if "', '" in ut or "',\n" in ut:
                # The LLM sometimes returns a stringified Python list; strip
                # brackets/quotes and split on the comma separator.
                clean_ut = ut.replace("[", "").replace("]", "").replace("'", "")
                split_ut = [u.strip() for u in clean_ut.split(", ") if u.strip()]
                processed.extend(split_ut)
            else:
                processed.append(ut.strip())
        return processed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import logging
2+
import os
3+
from collections import defaultdict
4+
from unittest.mock import AsyncMock, Mock, patch
5+
6+
import pytest
7+
8+
from autointent import Dataset
9+
from autointent.custom_types import Split
10+
from autointent.generation.utterances import DatasetBalancer, Generator
11+
from autointent.generation.utterances.basic.chat_templates._synthesizer_en import EnglishSynthesizerTemplate
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
@pytest.fixture
def mock_generator():
    """Provide a Generator mock whose sync and async completions both return "test_utterance"."""
    fake = Mock(spec=Generator)
    fake.get_chat_completion.return_value = "test_utterance"
    fake.get_chat_completion_async = AsyncMock(return_value="test_utterance")
    return fake
22+
23+
24+
@pytest.fixture
def mock_prompt_maker():
    """Provide a prompt-maker stub that yields a single mock message."""
    prompt_stub = Mock(return_value=[Mock()])
    return prompt_stub
27+
28+
29+
@pytest.fixture
def unbalanced_dataset():
    """Two-intent dataset where class 0 has two train utterances and class 1 only one."""
    intents = [{"id": 0, "name": "A"}, {"id": 1, "name": "B"}]
    train_samples = [
        {"utterance": "test a1", "label": 0},
        {"utterance": "test a2", "label": 0},
        {"utterance": "test b1", "label": 1},
    ]
    return Dataset.from_dict({"intents": intents, "train": train_samples})
41+
42+
43+
def test_balancer(unbalanced_dataset, mock_generator, mock_prompt_maker):
    """Check that balancing grows the minority class to the majority count and keeps originals."""
    balancer = DatasetBalancer(generator=mock_generator, prompt_maker=mock_prompt_maker)

    logger.info("Before balancing:")
    for row in unbalanced_dataset[Split.TRAIN]:
        logger.info("Utterance: %s, Label: %s", row["utterance"], row["label"])

    # Intercept the utterance generator so the test stays offline.
    with patch.object(balancer.utterance_generator, "__call__") as fake_call:
        fake_call.return_value = ["generated_utterance"]
        balanced = balancer.balance(unbalanced_dataset)

    logger.info("After balancing:")
    for row in balanced[Split.TRAIN]:
        logger.info("Utterance: %s, Label: %s", row["utterance"], row["label"])

    labels = [row["label"] for row in balanced[Split.TRAIN]]
    assert labels.count(0) == 2, "Class 0 should not change"
    assert labels.count(1) == 2, "Class 1 should increase to 2"
    assert len(labels) == 4, "The total number of examples should be 4"

    before = {row["utterance"] for row in unbalanced_dataset[Split.TRAIN]}
    after = {row["utterance"] for row in balanced[Split.TRAIN]}
    assert before.issubset(after)
66+
67+
68+
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Requires OpenAI API key in environment")
def test_real_balancer():
    """Integration test against the real generator backend (requires OPENAI_API_KEY)."""
    test_data = {
        "intents": [{"id": 0, "name": "Book restaurant"}, {"id": 1, "name": "Check weather"}],
        "train": [
            {"utterance": "Book a table for two", "label": 0},
            {"utterance": "Reserve a table", "label": 0},
            {"utterance": "What's the weather in Moscow?", "label": 1},
        ],
    }
    dataset = Dataset.from_dict(test_data)
    template = EnglishSynthesizerTemplate(dataset, split="train")
    generator = Generator()
    # FIX: removed the redundant `evolutions = template` alias; pass the template directly.
    balancer = DatasetBalancer(generator=generator, prompt_maker=template, max_samples_per_class=3, async_mode=False)

    logger.info("Starting balance process...")
    balanced = balancer.balance(dataset)

    class_counts = defaultdict(int)
    for sample in balanced[Split.TRAIN]:
        class_counts[sample["label"]] += 1

    logger.info("Balancing results:")
    logger.info("Class 0 count: %s", class_counts[0])
    logger.info("Class 1 count: %s", class_counts[1])
    logger.info("Generated examples:")
    # FIX: hoist the original-utterance set out of the loop (it was rebuilt on every iteration).
    original_utterances = {s["utterance"] for s in test_data["train"]}
    for sample in balanced[Split.TRAIN]:
        if sample["utterance"] not in original_utterances:
            logger.info("[Class %s]: %s", sample["label"], sample["utterance"])

    assert class_counts[0] == 3, "Class 0 should have 3 examples"
    assert class_counts[1] == 3, "Class 1 should have 3 examples"
    assert len(balanced[Split.TRAIN]) == 6, "Total examples should be 6"

0 commit comments

Comments
 (0)