Skip to content

Commit d4e8992

Browse files
committed
feat: batch_size in basic augmentation
1 parent fb18058 commit d4e8992

File tree

3 files changed

+92
-18
lines changed

3 files changed

+92
-18
lines changed

autointent/generation/utterances/basic/utterance_generator.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ class UtteranceGenerator:
2121
punctuation, and length of the desired generations.
2222
"""
2323

24-
def __init__(self,
25-
generator: Generator,
26-
prompt_maker: Callable[[Intent, int], list[Message]],
27-
async_mode: bool = False
28-
) -> None:
24+
def __init__(
25+
self,
26+
generator: Generator,
27+
prompt_maker: Callable[[Intent, int], list[Message]],
28+
async_mode: bool = False
29+
) -> None:
2930
"""Initialize."""
3031
self.generator = generator
3132
self.prompt_maker = prompt_maker
@@ -49,27 +50,33 @@ def augment(
4950
split_name: str = Split.TRAIN,
5051
n_generations: int = 5,
5152
update_split: bool = True,
53+
batch_size: int | None = None
5254
) -> list[Sample]:
5355
"""
5456
Augment some split of dataset.
5557
56-
Note that for now it supports only single-label datasets.
58+
:param dataset: Dataset object
59+
:param split_name: Dataset split (default is TRAIN)
60+
:param n_generations: Number of utterances to generate per intent
61+
:param update_split: Whether to update the dataset split
62+
:param batch_size: Batch size for async generation (None means all at once)
63+
:return: List of generated samples
5764
"""
5865
if self.async_mode:
59-
return asyncio.run(self._augment_async(dataset, split_name, n_generations, update_split))
66+
return asyncio.run(self._augment_async(dataset, split_name, n_generations, update_split, batch_size))
67+
6068
original_split = dataset[split_name]
6169
new_samples = []
6270
for intent in dataset.intents:
63-
generated_utterances = self(
64-
intent_data=intent,
65-
n_generations=n_generations,
66-
)
71+
generated_utterances = self(intent_data=intent, n_generations=n_generations)
6772
new_samples.extend(
6873
[{Dataset.label_feature: intent.id, Dataset.utterance_feature: ut} for ut in generated_utterances]
6974
)
75+
7076
if update_split:
7177
generated_split = HFDataset.from_list(new_samples)
7278
dataset[split_name] = concatenate_datasets([original_split, generated_split])
79+
7380
return [Sample(**sample) for sample in new_samples]
7481

7582
async def _augment_async(
@@ -78,19 +85,32 @@ async def _augment_async(
7885
split_name: str = Split.TRAIN,
7986
n_generations: int = 5,
8087
update_split: bool = True,
88+
batch_size: int | None = None
8189
) -> list[Sample]:
8290
"""
83-
Augment some split of dataset asynchronously.
84-
85-
Note that for now it supports only single-label datasets.
91+
Augment some split of dataset asynchronously in batches.
92+
93+
:param dataset: Dataset object
94+
:param split_name: Dataset split (default is TRAIN)
95+
:param n_generations: Number of utterances to generate per intent
96+
:param update_split: Whether to update the dataset split
97+
:param batch_size: Batch size for async generation (None means all at once)
98+
:return: List of generated samples
8699
"""
87100
original_split = dataset[split_name]
88101
new_samples = []
89-
tasks = []
90102

91-
tasks = [self._call_async(intent_data=intent, n_generations=n_generations) for intent in dataset.intents]
103+
if not batch_size:
104+
tasks = [self._call_async(intent_data=intent, n_generations=n_generations) for intent in dataset.intents]
105+
results = await asyncio.gather(*tasks)
92106

93-
results = await asyncio.gather(*tasks)
107+
else:
108+
results = []
109+
for start_idx in range(0, len(dataset.intents), batch_size):
110+
batch_intents = dataset.intents[start_idx:start_idx + batch_size]
111+
tasks = [self._call_async(intent_data=intent, n_generations=n_generations) for intent in batch_intents]
112+
batch_results = await asyncio.gather(*tasks)
113+
results.extend(batch_results)
94114

95115
for i, generated_utterances in enumerate(results):
96116
intent = dataset.intents[i]
@@ -113,4 +133,4 @@ def _extract_utterances(response_text: str) -> list[str]:
113133
"""
114134
raw_utterances = response_text.split("\n")
115135
# remove enumeration
116-
return [ut[ut.find(" ") + 1 :] for ut in raw_utterances]
136+
return [ut[ut.find(" ") + 1:] if " " in ut else ut for ut in raw_utterances]

tests/generation/utterances/test_basic_synthesizer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,24 @@ def test_on_dataset_async(dataset):
7474

7575
assert n_before + len(new_samples) == n_after
7676
assert len(new_samples) == len(dataset.intents)
77+
78+
def test_on_dataset_async_with_batch_size(dataset):
79+
mock_llm = AsyncMock()
80+
mock_llm.get_chat_completion_async.return_value = "1. LLM answer"
81+
82+
split_name = "train_0"
83+
84+
template = SynthesizerChatTemplate(dataset, split=split_name)
85+
augmenter = UtteranceGenerator(mock_llm, template, async_mode=True)
86+
87+
batch_size = 2
88+
new_samples = augmenter.augment(dataset, split_name=split_name, update_split=False, batch_size=batch_size)
89+
90+
assert len(new_samples) == len(dataset.intents)
91+
assert all(sample.utterance == "LLM answer" for sample in new_samples)
92+
93+
batch_size = len(dataset.intents) + 5
94+
new_samples = augmenter.augment(dataset, split_name=split_name, update_split=False, batch_size=batch_size)
95+
96+
assert len(new_samples) == len(dataset.intents)
97+
assert all(sample.utterance == "LLM answer" for sample in new_samples)

tests/generation/utterances/test_evolver.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,36 @@ def test_on_dataset_evolver_async(dataset):
5656

5757
assert n_before + len(new_samples) == n_after
5858
assert len(new_samples) == n_before
59+
60+
def test_on_dataset_evolver_async_with_batch_size(dataset):
61+
mock_llm = AsyncMock()
62+
mock_llm.get_chat_completion_async.return_value = "LLM answer"
63+
64+
split_name = "train_0"
65+
66+
template = AbstractEvolution()
67+
augmenter = UtteranceEvolver(mock_llm, [template], async_mode=True)
68+
69+
batch_size = 2
70+
new_samples = augmenter.augment(
71+
dataset,
72+
split_name=split_name,
73+
n_evolutions=1,
74+
update_split=False,
75+
batch_size=batch_size
76+
)
77+
78+
assert len(new_samples) == len(dataset[split_name])
79+
assert all(sample.utterance == "LLM answer" for sample in new_samples)
80+
81+
batch_size = len(dataset[split_name]) + 5
82+
new_samples = augmenter.augment(
83+
dataset,
84+
split_name=split_name,
85+
n_evolutions=1,
86+
update_split=False,
87+
batch_size=batch_size
88+
)
89+
90+
assert len(new_samples) == len(dataset[split_name])
91+
assert all(sample.utterance == "LLM answer" for sample in new_samples)

0 commit comments

Comments
 (0)