Skip to content

Commit 8b8abef

Browse files
authored
Feat/async augmentations (#116)
* fix: change location of prompts * feat: update evolution class * feat: update chat templates * feat: added new evolutions * feat: change test structure & add test_generator * fix: fixex ruff * fix: fixed import * fix: fixed chat_templates * feat: async evolutions * feat: async basic * feat: async basic * feat: added tests * feat: tests for generator * fix: main * feat: batch_size in basic augmentation * fixed error * fixed error * feat: default batch size * feat: batch size with requests not prompts for evolver * fix: deleted call_async * fix: fixed lint * fix: tests
1 parent 166e94a commit 8b8abef

File tree

8 files changed

+287
-20
lines changed

8 files changed

+287
-20
lines changed

autointent/generation/utterances/basic/cli.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
def main() -> None:
17-
"""ClI endpoint."""
17+
"""CLI endpoint."""
1818
parser = ArgumentParser()
1919
parser.add_argument(
2020
"--input-path",
@@ -48,11 +48,12 @@ def main() -> None:
4848
default=5,
4949
help="Number of utterances to use as an example for augmentation",
5050
)
51+
parser.add_argument("--async-mode", action="store_true", help="Enable asynchronous generation")
5152
args = parser.parse_args()
5253

5354
dataset = load_dataset(args.input_path)
5455
template = SynthesizerChatTemplate(dataset, args.split, max_sample_utterances=args.n_sample_utterances)
55-
generator = UtteranceGenerator(Generator(), template)
56+
generator = UtteranceGenerator(Generator(), template, async_mode=args.async_mode)
5657

5758
n_before = len(dataset[args.split])
5859
new_samples = generator.augment(dataset, split_name=args.split, n_generations=args.n_generations)

autointent/generation/utterances/basic/utterance_generator.py

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Basic generation of new utterances from existing ones."""
22

3+
import asyncio
34
from collections.abc import Callable
45

56
from datasets import Dataset as HFDataset
@@ -17,46 +18,103 @@ class UtteranceGenerator:
1718
Basic generation of new utterances from existing ones.
1819
1920
This augmentation method simply prompts LLM to look at existing examples
20-
and generate similar. Additionaly it can consider some aspects of style,
21-
punctuation and length of the desired generations.
21+
and generate similar. Additionally, it can consider some aspects of style,
22+
punctuation, and length of the desired generations.
2223
"""
2324

24-
def __init__(self, generator: Generator, prompt_maker: Callable[[Intent, int], list[Message]]) -> None:
25+
def __init__(
26+
self, generator: Generator, prompt_maker: Callable[[Intent, int], list[Message]], async_mode: bool = False
27+
) -> None:
2528
"""Initialize."""
2629
self.generator = generator
2730
self.prompt_maker = prompt_maker
31+
self.async_mode = async_mode
2832

2933
def __call__(self, intent_data: Intent, n_generations: int) -> list[str]:
3034
"""Generate new utterances."""
3135
messages = self.prompt_maker(intent_data, n_generations)
3236
response_text = self.generator.get_chat_completion(messages)
3337
return _extract_utterances(response_text)
3438

39+
async def _call_async(self, intent_data: Intent, n_generations: int) -> list[str]:
40+
"""Generate new utterances asynchronously."""
41+
messages = self.prompt_maker(intent_data, n_generations)
42+
response_text = await self.generator.get_chat_completion_async(messages)
43+
return _extract_utterances(response_text)
44+
3545
def augment(
3646
self,
3747
dataset: Dataset,
3848
split_name: str = Split.TRAIN,
3949
n_generations: int = 5,
4050
update_split: bool = True,
51+
batch_size: int = 4,
4152
) -> list[Sample]:
4253
"""
4354
Augment some split of dataset.
4455
45-
TODO Note that for now it supports only single-label datasets.
56+
:param dataset: Dataset object
57+
:param split_name: Dataset split (default is TRAIN)
58+
:param n_generations: Number of utterances to generate per intent
59+
:param update_split: Whether to update the dataset split
60+
:param batch_size: Batch size for async generation
61+
:return: List of generated samples
4662
"""
63+
if self.async_mode:
64+
return asyncio.run(self._augment_async(dataset, split_name, n_generations, update_split, batch_size))
65+
4766
original_split = dataset[split_name]
4867
new_samples = []
4968
for intent in dataset.intents:
50-
generated_utterances = self(
51-
intent_data=intent,
52-
n_generations=n_generations,
69+
generated_utterances = self(intent_data=intent, n_generations=n_generations)
70+
new_samples.extend(
71+
[{Dataset.label_feature: intent.id, Dataset.utterance_feature: ut} for ut in generated_utterances]
5372
)
73+
74+
if update_split:
75+
generated_split = HFDataset.from_list(new_samples)
76+
dataset[split_name] = concatenate_datasets([original_split, generated_split])
77+
78+
return [Sample(**sample) for sample in new_samples]
79+
80+
async def _augment_async(
81+
self,
82+
dataset: Dataset,
83+
split_name: str = Split.TRAIN,
84+
n_generations: int = 5,
85+
update_split: bool = True,
86+
batch_size: int = 4,
87+
) -> list[Sample]:
88+
"""
89+
Augment some split of dataset asynchronously in batches.
90+
91+
:param dataset: Dataset object
92+
:param split_name: Dataset split (default is TRAIN)
93+
:param n_generations: Number of utterances to generate per intent
94+
:param update_split: Whether to update the dataset split
95+
:param batch_size: Batch size for async generation
96+
:return: List of generated samples
97+
"""
98+
original_split = dataset[split_name]
99+
new_samples = []
100+
101+
results = []
102+
for start_idx in range(0, len(dataset.intents), batch_size):
103+
batch_intents = dataset.intents[start_idx : start_idx + batch_size]
104+
tasks = [self._call_async(intent_data=intent, n_generations=n_generations) for intent in batch_intents]
105+
batch_results = await asyncio.gather(*tasks)
106+
results.extend(batch_results)
107+
108+
for i, generated_utterances in enumerate(results):
109+
intent = dataset.intents[i]
54110
new_samples.extend(
55111
[{Dataset.label_feature: intent.id, Dataset.utterance_feature: ut} for ut in generated_utterances]
56112
)
113+
57114
if update_split:
58115
generated_split = HFDataset.from_list(new_samples)
59116
dataset[split_name] = concatenate_datasets([original_split, generated_split])
117+
60118
return [Sample(**sample) for sample in new_samples]
61119

62120

@@ -68,4 +126,4 @@ def _extract_utterances(response_text: str) -> list[str]:
68126
"""
69127
raw_utterances = response_text.split("\n")
70128
# remove enumeration
71-
return [ut[ut.find(" ") + 1 :] for ut in raw_utterances]
129+
return [ut[ut.find(" ") + 1 :] if " " in ut else ut for ut in raw_utterances]

autointent/generation/utterances/evolution/cli.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def main() -> None:
5353
parser.add_argument("--funny", action="store_true", help="Whether to use `Funny` evolution")
5454
parser.add_argument("--goofy", action="store_true", help="Whether to use `Goofy` evolution")
5555
parser.add_argument("--informal", action="store_true", help="Whether to use `Informal` evolution")
56+
parser.add_argument("--async-mode", action="store_true", help="Enable asynchronous generation")
5657
parser.add_argument("--seed", type=int, default=0)
5758

5859
args = parser.parse_args()
@@ -80,7 +81,7 @@ def main() -> None:
8081
dataset = load_dataset(args.input_path)
8182
n_before = len(dataset[args.split])
8283

83-
generator = UtteranceEvolver(Generator(), evolutions, args.seed)
84+
generator = UtteranceEvolver(Generator(), evolutions, args.seed, async_mode=args.async_mode)
8485
new_samples = generator.augment(dataset, split_name=args.split, n_evolutions=args.n_evolutions)
8586
n_after = len(dataset[args.split])
8687

autointent/generation/utterances/evolution/evolver.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Deeply inspired by DeepEval evolutions.
55
"""
66

7+
import asyncio
78
import random
89
from collections.abc import Callable, Sequence
910

@@ -26,31 +27,58 @@ class UtteranceEvolver:
2627
"""
2728

2829
def __init__(
29-
self, generator: Generator, prompt_makers: Sequence[Callable[[str, Intent], list[Message]]], seed: int = 0
30+
self,
31+
generator: Generator,
32+
prompt_makers: Sequence[Callable[[str, Intent], list[Message]]],
33+
seed: int = 0,
34+
async_mode: bool = False,
3035
) -> None:
3136
"""Initialize."""
3237
self.generator = generator
3338
self.prompt_makers = prompt_makers
39+
self.async_mode = async_mode
3440
random.seed(seed)
3541

3642
def _evolve(self, utterance: str, intent_data: Intent) -> str:
37-
"""Apply evolutions single time."""
43+
"""Apply evolutions single time synchronously."""
3844
maker = random.choice(self.prompt_makers)
3945
chat = maker(utterance, intent_data)
4046
return self.generator.get_chat_completion(chat)
4147

48+
async def _evolve_async(self, utterance: str, intent_data: Intent) -> str:
49+
"""Apply evolutions a single time (asynchronously)."""
50+
maker = random.choice(self.prompt_makers)
51+
chat = maker(utterance, intent_data)
52+
return await self.generator.get_chat_completion_async(chat)
53+
4254
def __call__(self, utterance: str, intent_data: Intent, n_evolutions: int = 1) -> list[str]:
43-
"""Apply evolutions mupltiple times."""
55+
"""Apply evolutions multiple times (synchronously)."""
4456
return [self._evolve(utterance, intent_data) for _ in range(n_evolutions)]
4557

4658
def augment(
47-
self, dataset: Dataset, split_name: str = Split.TRAIN, n_evolutions: int = 1, update_split: bool = True
59+
self,
60+
dataset: Dataset,
61+
split_name: str = Split.TRAIN,
62+
n_evolutions: int = 1,
63+
update_split: bool = True,
64+
batch_size: int = 4,
4865
) -> list[Sample]:
4966
"""
5067
Augment some split of dataset.
5168
5269
Note that for now it supports only single-label datasets.
5370
"""
71+
if self.async_mode:
72+
return asyncio.run(
73+
self._augment_async(
74+
dataset=dataset,
75+
split_name=split_name,
76+
n_evolutions=n_evolutions,
77+
update_split=update_split,
78+
batch_size=batch_size,
79+
)
80+
)
81+
5482
original_split = dataset[split_name]
5583
new_samples = []
5684
for sample in original_split:
@@ -61,7 +89,43 @@ def augment(
6189
new_samples.extend(
6290
[{Dataset.label_feature: intent_data.id, Dataset.utterance_feature: ut} for ut in generated_utterances]
6391
)
92+
6493
if update_split:
6594
generated_split = HFDataset.from_list(new_samples)
6695
dataset[split_name] = concatenate_datasets([original_split, generated_split])
96+
97+
return [Sample(**sample) for sample in new_samples]
98+
99+
async def _augment_async(
100+
self,
101+
dataset: Dataset,
102+
split_name: str = Split.TRAIN,
103+
n_evolutions: int = 1,
104+
update_split: bool = True,
105+
batch_size: int = 4,
106+
) -> list[Sample]:
107+
original_split = dataset[split_name]
108+
new_samples = []
109+
110+
tasks = []
111+
labels = []
112+
for sample in original_split:
113+
utterance = sample[Dataset.utterance_feature]
114+
label = sample[Dataset.label_feature]
115+
intent_data = next(intent for intent in dataset.intents if intent.id == label)
116+
for _ in range(n_evolutions):
117+
tasks.append(self._evolve_async(utterance, intent_data))
118+
labels.append(intent_data.id)
119+
120+
for start_idx in range(0, len(tasks), batch_size):
121+
batch_tasks = tasks[start_idx : start_idx + batch_size]
122+
batch_labels = labels[start_idx : start_idx + batch_size]
123+
batch_results = await asyncio.gather(*batch_tasks)
124+
for result, intent_id in zip(batch_results, batch_labels, strict=False):
125+
new_samples.append({Dataset.label_feature: intent_id, Dataset.utterance_feature: result})
126+
127+
if update_split:
128+
generated_split = HFDataset.from_list(new_samples)
129+
dataset[split_name] = concatenate_datasets([original_split, generated_split])
130+
67131
return [Sample(**sample) for sample in new_samples]

autointent/generation/utterances/generator.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,18 @@ class Generator:
1414
def __init__(self) -> None:
1515
"""Initialize."""
1616
load_dotenv()
17-
self.client = openai.OpenAI(base_url=os.environ["OPENAI_BASE_URL"], api_key=os.environ["OPENAI_API_KEY"])
17+
self.client = openai.OpenAI(
18+
base_url=os.environ["OPENAI_BASE_URL"],
19+
api_key=os.environ["OPENAI_API_KEY"]
20+
)
21+
self.async_client = openai.AsyncOpenAI(
22+
base_url=os.environ["OPENAI_BASE_URL"],
23+
api_key=os.environ["OPENAI_API_KEY"]
24+
)
1825
self.model_name = os.environ["OPENAI_MODEL_NAME"]
1926

2027
def get_chat_completion(self, messages: list[Message]) -> str:
21-
"""Prompt LLM and return its answer."""
28+
"""Prompt LLM and return its answer synchronously."""
2229
response = self.client.chat.completions.create(
2330
messages=messages, # type: ignore[arg-type]
2431
model=self.model_name,
@@ -28,3 +35,15 @@ def get_chat_completion(self, messages: list[Message]) -> str:
2835
temperature=0.7,
2936
)
3037
return response.choices[0].message.content # type: ignore[return-value]
38+
39+
async def get_chat_completion_async(self, messages: list[Message]) -> str:
40+
"""Prompt LLM and return its answer asynchronously."""
41+
response = await self.async_client.chat.completions.create(
42+
messages=messages, # type: ignore[arg-type]
43+
model=self.model_name,
44+
max_tokens=150,
45+
n=1,
46+
stop=None,
47+
temperature=0.7,
48+
)
49+
return response.choices[0].message.content # type: ignore[return-value]

tests/generation/utterances/test_basic_synthesizer.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import Mock
1+
from unittest.mock import AsyncMock, Mock
22

33
from autointent.generation.utterances import SynthesizerChatTemplate, UtteranceGenerator
44

@@ -50,3 +50,49 @@ def test_on_dataset(dataset):
5050

5151
assert n_before + len(new_samples) == n_after
5252
assert len(new_samples) == len(dataset.intents)
53+
54+
55+
def test_on_dataset_async(dataset):
56+
mock_llm = AsyncMock()
57+
mock_llm.get_chat_completion_async.return_value = "1. LLM answer"
58+
59+
split_name = "train_0"
60+
61+
template = SynthesizerChatTemplate(dataset, split=split_name)
62+
augmenter = UtteranceGenerator(mock_llm, template, async_mode=True)
63+
64+
n_before = len(dataset[split_name])
65+
new_samples = augmenter.augment(dataset, split_name=split_name, update_split=False)
66+
n_after = len(dataset[split_name])
67+
68+
assert n_before == n_after
69+
assert len(new_samples) == len(dataset.intents)
70+
assert all(sample.utterance == "LLM answer" for sample in new_samples)
71+
72+
n_before = len(dataset[split_name])
73+
new_samples = augmenter.augment(dataset, split_name=split_name, update_split=True)
74+
n_after = len(dataset[split_name])
75+
76+
assert n_before + len(new_samples) == n_after
77+
assert len(new_samples) == len(dataset.intents)
78+
79+
def test_on_dataset_async_with_batch_size(dataset):
80+
mock_llm = AsyncMock()
81+
mock_llm.get_chat_completion_async.return_value = "1. LLM answer"
82+
83+
split_name = "train_0"
84+
85+
template = SynthesizerChatTemplate(dataset, split=split_name)
86+
augmenter = UtteranceGenerator(mock_llm, template, async_mode=True)
87+
88+
batch_size = 2
89+
new_samples = augmenter.augment(dataset, split_name=split_name, update_split=False, batch_size=batch_size)
90+
91+
assert len(new_samples) == len(dataset.intents)
92+
assert all(sample.utterance == "LLM answer" for sample in new_samples)
93+
94+
batch_size = len(dataset.intents) + 5
95+
new_samples = augmenter.augment(dataset, split_name=split_name, update_split=False, batch_size=batch_size)
96+
97+
assert len(new_samples) == len(dataset.intents)
98+
assert all(sample.utterance == "LLM answer" for sample in new_samples)

0 commit comments

Comments (0)