Skip to content

Commit 9cf6f64

Browse files
committed
first working tested balancer
1 parent b3450a3 commit 9cf6f64

File tree

2 files changed

+92
-96
lines changed

2 files changed

+92
-96
lines changed
Lines changed: 40 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,47 @@
11
"""Module for balancing datasets through augmentation of underrepresented classes."""
22

33
from collections import defaultdict
4-
from typing import List
4+
from collections.abc import Callable
55

66
from autointent import Dataset
77
from autointent.custom_types import Split
8-
from autointent.generation.utterances.evolution.evolver import UtteranceEvolver
9-
from autointent.generation.utterances.generator import Generator
108
from autointent.generation.utterances.basic.utterance_generator import UtteranceGenerator
9+
from autointent.generation.utterances.generator import Generator
10+
from autointent.generation.utterances.schemas import Message
11+
from autointent.schemas import Intent
1112

1213

1314
class DatasetBalancer:
1415
"""Class for balancing dataset through example augmentation."""
1516

16-
class DatasetBalancer:
1717
def __init__(
1818
self,
1919
generator: Generator,
20-
evolutions: List,
21-
seed: int = 42,
20+
prompt_maker: Callable[[Intent, int], list[Message]],
2221
async_mode: bool = False,
2322
max_samples_per_class: int | None = None,
2423
) -> None:
25-
if not isinstance(generator, Generator):
26-
raise TypeError("Generator must be an instance of autointent.generation.utterances.generator.Generator")
27-
28-
if not isinstance(evolutions, list) or not all(callable(e) for e in evolutions):
29-
raise TypeError("Evolutions must be a list of callable objects")
30-
24+
"""
25+
Initialize the UtteranceBalancer.
26+
27+
Args:
28+
generator (Generator): The generator object used to create utterances.
29+
prompt_maker (Callable[[Intent, int], list[Message]]): A callable that creates prompts for the generator.
30+
seed (int, optional): The seed for random number generation. Defaults to 42.
31+
async_mode (bool, optional): Whether to run the generator in asynchronous mode. Defaults to False.
32+
max_samples_per_class (int | None, optional): The maximum number of samples per class. Must be a positive integer or None. Defaults to None.
33+
Raises:
34+
ValueError: If max_samples_per_class is not None and is less than or equal to 0.
35+
"""
3136
if max_samples_per_class is not None and max_samples_per_class <= 0:
32-
raise ValueError("max_samples_per_class must be a positive integer or None")
33-
34-
self.evolver = UtteranceGenerator(generator, evolutions, async_mode)
35-
self.max_samples = max_samples_per_class
37+
msg = "max_samples_per_class must be a positive integer or None"
38+
raise ValueError(msg)
3639

40+
self.evolver = UtteranceGenerator(generator=generator, prompt_maker=prompt_maker, async_mode=async_mode)
41+
self.max_samples = max_samples_per_class
3742

3843
def balance(
39-
self, dataset: Dataset, split: str = Split.TRAIN, n_evolutions: int = 3, batch_size: int = 4
44+
self, dataset: Dataset, split: str = Split.TRAIN, batch_size: int = 4
4045
) -> Dataset:
4146
"""
4247
Balances the specified dataset split.
@@ -54,12 +59,11 @@ def balance(
5459
class_counts = self._count_class_examples(dataset, split)
5560
max_count = max(class_counts.values())
5661
target_count = self.max_samples if self.max_samples is not None else max_count
57-
print(f"Target count per class: {target_count}") # Добавить логирование (translation: "add logging")
58-
62+
print(f"Target count per class: {target_count}")
5963
for class_id, current_count in class_counts.items():
6064
if current_count < target_count:
6165
needed = target_count - current_count
62-
self._augment_class(dataset, split, class_id, needed, n_evolutions, batch_size)
66+
self._augment_class(dataset, split, class_id, needed, batch_size)
6367

6468
return dataset
6569

@@ -71,13 +75,13 @@ def _count_class_examples(self, dataset: Dataset, split: str) -> dict[int, int]:
7175
return counts
7276

7377
def _augment_class(
74-
self, dataset: Dataset, split: str, class_id: int, needed: int, n_evolutions: int, batch_size: int
78+
self, dataset: Dataset, split: str, class_id: int, needed: int, batch_size: int
7579
) -> None:
7680
"""Generate additional examples for the class."""
7781
print("\n📂 DATASET BEFORE AUGMENTATION:")
7882
self._print_dataset(dataset, split)
7983
intent = next(i for i in dataset.intents if i.id == class_id)
80-
class_name = getattr(intent, 'name', f'class_{class_id}') # Получаем имя класса, если доступно (translation: "get the class name, if available")
84+
class_name = getattr(intent, "name", f"class_{class_id}")
8185
print(f"\n🚀 Starting augmentation for class {class_id} ({class_name})")
8286
print(f"📊 Initial samples: {len([s for s in dataset[split] if s[Dataset.label_feature] == class_id])}")
8387
print(f"🎯 Target needed: {needed} samples")
@@ -92,7 +96,7 @@ def _augment_class(
9296

9397
while total_generated < needed:
9498
print(f"\n🔄 Batch generation: {per_sample_evolutions} evolutions per sample")
95-
99+
96100
generated = self.evolver.augment(
97101
dataset, split_name=split, n_generations=per_sample_evolutions, update_split=True, batch_size=batch_size
98102
)
@@ -101,10 +105,10 @@ def _augment_class(
101105
print(f"✅ Generated {len(generated)} examples")
102106
if generated:
103107
print("🔠 Example generated utterances:")
104-
for i, example in enumerate(generated[:3]):
108+
for i, example in enumerate(generated[:3]):
105109
utterance = getattr(example, Dataset.utterance_feature, str(example))
106-
print(f" {i+1}. {utterance[:60]}...")
107-
110+
print(f" {i+1}. {utterance[:60]}...")
111+
108112
total_generated += len(generated)
109113
print(f"📈 Progress: {total_generated}/{needed} ({min(100, int(total_generated/needed*100))}%)")
110114

@@ -119,7 +123,6 @@ def _augment_class(
119123
print("\n📦 DATASET AFTER AUGMENTATION:")
120124
self._print_dataset(dataset, split)
121125
print("━" * 50)
122-
123126

124127
def _remove_extra_samples(self, dataset: Dataset, split: str, class_id: int, extra: int) -> None:
125128
"""Remove extra examples of the class."""
@@ -128,13 +131,14 @@ def _remove_extra_samples(self, dataset: Dataset, split: str, class_id: int, ext
128131

129132
new_data = [s for i, s in enumerate(dataset[split]) if i not in indices_to_remove]
130133
dataset[split] = dataset[split].from_list(new_data)
134+
131135
def _print_dataset(self, dataset: Dataset, split: str) -> None:
132-
"""Helper method to print dataset in readable format"""
133-
print(f"Split: {split}")
134-
print("-" * 50)
135-
for i, sample in enumerate(dataset[split]):
136-
label = sample[Dataset.label_feature]
137-
text = sample[Dataset.utterance_feature]
138-
print(f"{i+1:3d} | {label:15} | {text[:50]:<50}...")
139-
print("-" * 50)
140-
print(f"Total samples: {len(dataset[split])}\n")
136+
"""Print the dataset in a readable format."""
137+
print(f"Split: {split}")
138+
print("-" * 50)
139+
for i, sample in enumerate(dataset[split]):
140+
label = sample[Dataset.label_feature]
141+
text = sample[Dataset.utterance_feature]
142+
print(f"{i+1:3d} | {label:15} | {text[:50]:<50}...")
143+
print("-" * 50)
144+
print(f"Total samples: {len(dataset[split])}\n")
Lines changed: 52 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,117 +1,109 @@
1-
from collections import defaultdict
1+
import os
2+
from collections import defaultdict
23
from unittest.mock import AsyncMock, Mock, patch
34

45
import pytest
5-
import os
6-
from datasets import Dataset as HFDataset
6+
from datasets import Dataset as HFDataset
77

88
from autointent import Dataset
9-
from autointent.custom_types import Split
10-
from autointent.generation.utterances import AbstractEvolution, DatasetBalancer, Generator
11-
from autointent.schemas import Intent
9+
from autointent.custom_types import Split
10+
from autointent.generation.utterances import DatasetBalancer, Generator
11+
from autointent.generation.utterances.basic.chat_template import SynthesizerChatTemplate
12+
from autointent.schemas import Sample
1213

1314

1415
@pytest.fixture
1516
def mock_generator():
16-
generator = Mock(spec=Generator)
17+
generator = Mock(spec=Generator)
1718
generator.get_chat_completion.return_value = "test_utterance"
1819
generator.get_chat_completion_async = AsyncMock(return_value="test_utterance")
1920
return generator
2021

22+
2123
@pytest.fixture
22-
def mock_evolutions():
23-
return [Mock(side_effect=lambda x, y: []), Mock(side_effect=lambda x, y: [])]
24+
def mock_prompt_maker():
25+
return Mock(return_value=[Mock()])
2426

2527

2628
@pytest.fixture
2729
def unbalanced_dataset():
28-
return Dataset.from_dict({
29-
"intents": [{"id": 0, "name": "A"}, {"id": 1, "name": "B"}],
30-
"train": [
31-
{"utterance": "test a1", "label": 0},
32-
{"utterance": "test a2", "label": 0},
33-
{"utterance": "test b1", "label": 1},
34-
]
35-
})
30+
return Dataset.from_dict(
31+
{
32+
"intents": [{"id": 0, "name": "A"}, {"id": 1, "name": "B"}],
33+
"train": [
34+
{"utterance": "test a1", "label": 0},
35+
{"utterance": "test a2", "label": 0},
36+
{"utterance": "test b1", "label": 1},
37+
],
38+
}
39+
)
3640

3741

38-
def test_balancer(unbalanced_dataset, mock_generator, mock_evolutions):
39-
40-
balancer = DatasetBalancer(mock_generator, mock_evolutions)
41-
42+
def test_balancer(unbalanced_dataset, mock_generator, mock_prompt_maker):
43+
balancer = DatasetBalancer(generator=mock_generator, prompt_maker=mock_prompt_maker)
4244
print("\nBefore balancing:")
4345
for sample in unbalanced_dataset[Split.TRAIN]:
4446
print(f"Utterance: {sample['utterance']}, Label: {sample['label']}")
45-
46-
with patch.object(balancer.evolver, 'augment') as mock_augment:
47-
def augment_side_effect(dataset, split_name, n_evolutions, update_split, batch_size):
47+
48+
with patch.object(balancer.evolver, "augment") as mock_augment:
49+
50+
def augment_side_effect(dataset, split_name, n_generations, update_split, batch_size):
51+
new_sample = {"utterance": "generated_utterance", "label": 1}
4852
if update_split:
49-
new_sample = {"utterance": "generated_utterance", "label": 1}
5053
current_data = dataset[split_name].to_list()
5154
current_data.append(new_sample)
5255
dataset[split_name] = HFDataset.from_list(current_data)
53-
return [new_sample]
54-
56+
return [Sample(**new_sample)]
57+
5558
mock_augment.side_effect = augment_side_effect
56-
59+
5760
balanced = balancer.balance(unbalanced_dataset)
58-
61+
5962
print("\nAfter balancing:")
6063
for sample in balanced[Split.TRAIN]:
6164
print(f"Utterance: {sample['utterance']}, Label: {sample['label']}")
62-
65+
6366
labels = [s["label"] for s in balanced[Split.TRAIN]]
64-
assert labels.count(0) == 2, "Класс 0 не должен изменяться"
65-
assert labels.count(1) == 2, "Класс 1 должен увеличиться до 2"
66-
assert len(labels) == 4, "Общее количество примеров должно быть 4"
67-
67+
assert labels.count(0) == 2, "Class 0 should not change"
68+
assert labels.count(1) == 2, "Class 1 should increase to 2"
69+
assert len(labels) == 4, "The total number of examples should be 4"
70+
6871
original_utterances = {s["utterance"] for s in unbalanced_dataset[Split.TRAIN]}
6972
balanced_utterances = {s["utterance"] for s in balanced[Split.TRAIN]}
7073
assert original_utterances.issubset(balanced_utterances)
7174

72-
@pytest.mark.integration
73-
@pytest.mark.skipif(
74-
not os.getenv("OPENAI_API_KEY"),
75-
reason="Requires OpenAI API key in environment"
76-
)
75+
76+
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Requires OpenAI API key in environment")
7777
def test_real_balancer():
7878
test_data = {
79-
"intents": [
80-
{"id": 0, "name": "Book restaurant"},
81-
{"id": 1, "name": "Check weather"}
82-
],
79+
"intents": [{"id": 0, "name": "Book restaurant"}, {"id": 1, "name": "Check weather"}],
8380
"train": [
8481
{"utterance": "Book a table for two", "label": 0},
85-
{"utterance": "Reserve a table", "label": 0}, # Добавлен второй пример (translation: "second example added")
86-
87-
{"utterance": "What's the weather in Moscow?", "label": 1}
88-
]
82+
{"utterance": "Reserve a table", "label": 0},
83+
{"utterance": "What's the weather in Moscow?", "label": 1},
84+
],
8985
}
9086
dataset = Dataset.from_dict(test_data)
87+
template = SynthesizerChatTemplate(dataset, split="train")
88+
generator = Generator()
89+
evolutions = template
90+
balancer = DatasetBalancer(generator=generator, prompt_maker=evolutions, max_samples_per_class=3, async_mode=False)
9191

92-
evolutions = [AbstractEvolution()]
93-
balancer = DatasetBalancer(
94-
generator=Generator(),
95-
evolutions=evolutions,
96-
max_samples_per_class=3,
97-
async_mode=False
98-
)
99-
10092
print("\nStarting balance process...")
101-
balanced = balancer.balance(dataset, n_evolutions=1)
102-
93+
balanced = balancer.balance(dataset)
94+
10395
class_counts = defaultdict(int)
10496
for sample in balanced[Split.TRAIN]:
10597
class_counts[sample["label"]] += 1
106-
98+
10799
print("\nBalancing results:")
108100
print(f"Class 0 count: {class_counts[0]}")
109101
print(f"Class 1 count: {class_counts[1]}")
110102
print("\nGenerated examples:")
111103
for sample in balanced[Split.TRAIN]:
112104
if sample["utterance"] not in {s["utterance"] for s in test_data["train"]}:
113105
print(f"[Class {sample['label']}]: {sample['utterance']}")
114-
106+
115107
assert class_counts[0] == 3, "Class 0 should have 3 examples"
116108
assert class_counts[1] == 3, "Class 1 should have 3 examples"
117-
assert len(balanced[Split.TRAIN]) == 6, "Total examples should be 6"
109+
assert len(balanced[Split.TRAIN]) == 6, "Total examples should be 6"

0 commit comments

Comments
 (0)