diff --git a/autointent/generation/utterances/basic/cli.py b/autointent/generation/utterances/basic/cli.py
index baa6968c8..b1d9d5264 100644
--- a/autointent/generation/utterances/basic/cli.py
+++ b/autointent/generation/utterances/basic/cli.py
@@ -14,7 +14,7 @@ def main() -> None:
-    """ClI endpoint."""
+    """CLI endpoint."""
     parser = ArgumentParser()
     parser.add_argument(
         "--input-path",
@@ -48,11 +48,12 @@ def main() -> None:
         default=5,
         help="Number of utterances to use as an example for augmentation",
     )
+    parser.add_argument("--async-mode", action="store_true", help="Enable asynchronous generation")
     args = parser.parse_args()
 
     dataset = load_dataset(args.input_path)
     template = SynthesizerChatTemplate(dataset, args.split, max_sample_utterances=args.n_sample_utterances)
-    generator = UtteranceGenerator(Generator(), template)
+    generator = UtteranceGenerator(Generator(), template, async_mode=args.async_mode)
     n_before = len(dataset[args.split])
     new_samples = generator.augment(dataset, split_name=args.split, n_generations=args.n_generations)
diff --git a/autointent/generation/utterances/basic/utterance_generator.py b/autointent/generation/utterances/basic/utterance_generator.py
index 1d962272a..ac6189f8c 100644
--- a/autointent/generation/utterances/basic/utterance_generator.py
+++ b/autointent/generation/utterances/basic/utterance_generator.py
@@ -1,5 +1,6 @@
 """Basic generation of new utterances from existing ones."""
 
+import asyncio
 from collections.abc import Callable
 
 from datasets import Dataset as HFDataset
@@ -17,14 +18,17 @@ class UtteranceGenerator:
     Basic generation of new utterances from existing ones.
 
     This augmentation method simply prompts LLM to look at existing examples
-    and generate similar. Additionaly it can consider some aspects of style,
-    punctuation and length of the desired generations.
+    and generate similar. Additionally, it can consider some aspects of style,
+    punctuation, and length of the desired generations.
     """
 
-    def __init__(self, generator: Generator, prompt_maker: Callable[[Intent, int], list[Message]]) -> None:
+    def __init__(
+        self, generator: Generator, prompt_maker: Callable[[Intent, int], list[Message]], async_mode: bool = False
+    ) -> None:
         """Initialize."""
         self.generator = generator
         self.prompt_maker = prompt_maker
+        self.async_mode = async_mode
 
     def __call__(self, intent_data: Intent, n_generations: int) -> list[str]:
         """Generate new utterances."""
@@ -32,31 +36,85 @@ def __call__(self, intent_data: Intent, n_generations: int) -> list[str]:
         response_text = self.generator.get_chat_completion(messages)
         return _extract_utterances(response_text)
 
+    async def _call_async(self, intent_data: Intent, n_generations: int) -> list[str]:
+        """Generate new utterances asynchronously."""
+        messages = self.prompt_maker(intent_data, n_generations)
+        response_text = await self.generator.get_chat_completion_async(messages)
+        return _extract_utterances(response_text)
+
     def augment(
         self,
         dataset: Dataset,
         split_name: str = Split.TRAIN,
         n_generations: int = 5,
         update_split: bool = True,
+        batch_size: int = 4,
     ) -> list[Sample]:
         """
         Augment some split of dataset.
 
-        TODO Note that for now it supports only single-label datasets.
+        :param dataset: Dataset object
+        :param split_name: Dataset split (default is TRAIN)
+        :param n_generations: Number of utterances to generate per intent
+        :param update_split: Whether to update the dataset split
+        :param batch_size: Batch size for async generation
+        :return: List of generated samples
         """
+        if self.async_mode:
+            return asyncio.run(self._augment_async(dataset, split_name, n_generations, update_split, batch_size))
+
         original_split = dataset[split_name]
         new_samples = []
         for intent in dataset.intents:
-            generated_utterances = self(
-                intent_data=intent,
-                n_generations=n_generations,
+            generated_utterances = self(intent_data=intent, n_generations=n_generations)
+            new_samples.extend(
+                [{Dataset.label_feature: intent.id, Dataset.utterance_feature: ut} for ut in generated_utterances]
             )
+
+        if update_split:
+            generated_split = HFDataset.from_list(new_samples)
+            dataset[split_name] = concatenate_datasets([original_split, generated_split])
+
+        return [Sample(**sample) for sample in new_samples]
+
+    async def _augment_async(
+        self,
+        dataset: Dataset,
+        split_name: str = Split.TRAIN,
+        n_generations: int = 5,
+        update_split: bool = True,
+        batch_size: int = 4,
+    ) -> list[Sample]:
+        """
+        Augment some split of dataset asynchronously in batches.
+
+        :param dataset: Dataset object
+        :param split_name: Dataset split (default is TRAIN)
+        :param n_generations: Number of utterances to generate per intent
+        :param update_split: Whether to update the dataset split
+        :param batch_size: Batch size for async generation
+        :return: List of generated samples
+        """
+        original_split = dataset[split_name]
+        new_samples = []
+
+        results = []
+        for start_idx in range(0, len(dataset.intents), batch_size):
+            batch_intents = dataset.intents[start_idx : start_idx + batch_size]
+            tasks = [self._call_async(intent_data=intent, n_generations=n_generations) for intent in batch_intents]
+            batch_results = await asyncio.gather(*tasks)
+            results.extend(batch_results)
+
+        for i, generated_utterances in enumerate(results):
+            intent = dataset.intents[i]
             new_samples.extend(
                 [{Dataset.label_feature: intent.id, Dataset.utterance_feature: ut} for ut in generated_utterances]
             )
+
         if update_split:
             generated_split = HFDataset.from_list(new_samples)
             dataset[split_name] = concatenate_datasets([original_split, generated_split])
+
         return [Sample(**sample) for sample in new_samples]
@@ -68,4 +126,4 @@ def _extract_utterances(response_text: str) -> list[str]:
     """
     raw_utterances = response_text.split("\n")
     # remove enumeration
-    return [ut[ut.find(" ") + 1 :] for ut in raw_utterances]
+    return [ut[ut.find(" ") + 1 :] if " " in ut else ut for ut in raw_utterances]
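The ordered-results guarantee of asyncio.gather is what _augment_async relies on when it re-associates generations with dataset.intents by index. A minimal, self-contained sketch of the same batching pattern; fake_llm_call is a stand-in for Generator.get_chat_completion_async, not part of the library:

import asyncio

async def fake_llm_call(prompt: str) -> str:
    # stand-in for Generator.get_chat_completion_async
    await asyncio.sleep(0.01)
    return f"response to: {prompt}"

async def run_in_batches(prompts: list[str], batch_size: int = 4) -> list[str]:
    # at most `batch_size` requests in flight at once; gather preserves input
    # order, which is what lets _augment_async realign outputs with dataset.intents
    results: list[str] = []
    for start in range(0, len(prompts), batch_size):
        batch = prompts[start : start + batch_size]
        results.extend(await asyncio.gather(*(fake_llm_call(p) for p in batch)))
    return results

print(asyncio.run(run_in_batches([f"intent {i}" for i in range(10)])))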
diff --git a/autointent/generation/utterances/evolution/cli.py b/autointent/generation/utterances/evolution/cli.py
index fc271b45a..8438f739e 100644
--- a/autointent/generation/utterances/evolution/cli.py
+++ b/autointent/generation/utterances/evolution/cli.py
@@ -53,6 +53,7 @@ def main() -> None:
     parser.add_argument("--funny", action="store_true", help="Whether to use `Funny` evolution")
     parser.add_argument("--goofy", action="store_true", help="Whether to use `Goofy` evolution")
     parser.add_argument("--informal", action="store_true", help="Whether to use `Informal` evolution")
+    parser.add_argument("--async-mode", action="store_true", help="Enable asynchronous generation")
     parser.add_argument("--seed", type=int, default=0)
     args = parser.parse_args()
 
@@ -80,7 +81,7 @@ def main() -> None:
     dataset = load_dataset(args.input_path)
     n_before = len(dataset[args.split])
 
-    generator = UtteranceEvolver(Generator(), evolutions, args.seed)
+    generator = UtteranceEvolver(Generator(), evolutions, args.seed, async_mode=args.async_mode)
     new_samples = generator.augment(dataset, split_name=args.split, n_evolutions=args.n_evolutions)
     n_after = len(dataset[args.split])
diff --git a/autointent/generation/utterances/evolution/evolver.py b/autointent/generation/utterances/evolution/evolver.py
index bdcf10692..9183fd0a0 100644
--- a/autointent/generation/utterances/evolution/evolver.py
+++ b/autointent/generation/utterances/evolution/evolver.py
@@ -4,6 +4,7 @@
 Deeply inspired by DeepEval evolutions.
 """
 
+import asyncio
 import random
 from collections.abc import Callable, Sequence
 
@@ -26,31 +27,58 @@ class UtteranceEvolver:
     """
 
     def __init__(
-        self, generator: Generator, prompt_makers: Sequence[Callable[[str, Intent], list[Message]]], seed: int = 0
+        self,
+        generator: Generator,
+        prompt_makers: Sequence[Callable[[str, Intent], list[Message]]],
+        seed: int = 0,
+        async_mode: bool = False,
     ) -> None:
         """Initialize."""
         self.generator = generator
         self.prompt_makers = prompt_makers
+        self.async_mode = async_mode
         random.seed(seed)
 
     def _evolve(self, utterance: str, intent_data: Intent) -> str:
-        """Apply evolutions single time."""
+        """Apply evolutions a single time (synchronously)."""
         maker = random.choice(self.prompt_makers)
         chat = maker(utterance, intent_data)
         return self.generator.get_chat_completion(chat)
 
+    async def _evolve_async(self, utterance: str, intent_data: Intent) -> str:
+        """Apply evolutions a single time (asynchronously)."""
+        maker = random.choice(self.prompt_makers)
+        chat = maker(utterance, intent_data)
+        return await self.generator.get_chat_completion_async(chat)
+
     def __call__(self, utterance: str, intent_data: Intent, n_evolutions: int = 1) -> list[str]:
-        """Apply evolutions mupltiple times."""
+        """Apply evolutions multiple times (synchronously)."""
         return [self._evolve(utterance, intent_data) for _ in range(n_evolutions)]
 
     def augment(
-        self, dataset: Dataset, split_name: str = Split.TRAIN, n_evolutions: int = 1, update_split: bool = True
+        self,
+        dataset: Dataset,
+        split_name: str = Split.TRAIN,
+        n_evolutions: int = 1,
+        update_split: bool = True,
+        batch_size: int = 4,
    ) -> list[Sample]:
         """
         Augment some split of dataset.
 
         Note that for now it supports only single-label datasets.
         """
+        if self.async_mode:
+            return asyncio.run(
+                self._augment_async(
+                    dataset=dataset,
+                    split_name=split_name,
+                    n_evolutions=n_evolutions,
+                    update_split=update_split,
+                    batch_size=batch_size,
+                )
+            )
+
         original_split = dataset[split_name]
         new_samples = []
         for sample in original_split:
@@ -61,7 +89,43 @@ def augment(
             new_samples.extend(
                 [{Dataset.label_feature: intent_data.id, Dataset.utterance_feature: ut} for ut in generated_utterances]
             )
+
         if update_split:
             generated_split = HFDataset.from_list(new_samples)
             dataset[split_name] = concatenate_datasets([original_split, generated_split])
+
+        return [Sample(**sample) for sample in new_samples]
+
+    async def _augment_async(
+        self,
+        dataset: Dataset,
+        split_name: str = Split.TRAIN,
+        n_evolutions: int = 1,
+        update_split: bool = True,
+        batch_size: int = 4,
+    ) -> list[Sample]:
+        """Augment some split of dataset asynchronously in batches."""
+        original_split = dataset[split_name]
+        new_samples = []
+
+        tasks = []
+        labels = []
+        for sample in original_split:
+            utterance = sample[Dataset.utterance_feature]
+            label = sample[Dataset.label_feature]
+            intent_data = next(intent for intent in dataset.intents if intent.id == label)
+            for _ in range(n_evolutions):
+                tasks.append(self._evolve_async(utterance, intent_data))
+                labels.append(intent_data.id)
+
+        for start_idx in range(0, len(tasks), batch_size):
+            batch_tasks = tasks[start_idx : start_idx + batch_size]
+            batch_labels = labels[start_idx : start_idx + batch_size]
+            batch_results = await asyncio.gather(*batch_tasks)
+            for result, intent_id in zip(batch_results, batch_labels, strict=False):
+                new_samples.append({Dataset.label_feature: intent_id, Dataset.utterance_feature: result})
+
+        if update_split:
+            generated_split = HFDataset.from_list(new_samples)
+            dataset[split_name] = concatenate_datasets([original_split, generated_split])
+
         return [Sample(**sample) for sample in new_samples]
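One property of the batch-then-gather approach used by both _augment_async methods is that each batch waits for its slowest request before the next batch starts. A sliding-window alternative, sketched below with an asyncio.Semaphore, keeps batch_size requests in flight continuously; this is not what the diff implements, just a possible follow-up:

import asyncio

async def evolve_all(utterances: list[str], batch_size: int = 4) -> list[str]:
    # cap concurrency with a semaphore instead of fixed batches: as soon as
    # one request finishes, the next one starts
    semaphore = asyncio.Semaphore(batch_size)

    async def evolve_one(utterance: str) -> str:
        async with semaphore:
            await asyncio.sleep(0.01)  # stand-in for an LLM round-trip
            return utterance.upper()

    return await asyncio.gather(*(evolve_one(ut) for ut in utterances))

print(asyncio.run(evolve_all([f"utterance {i}" for i in range(10)])))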
""" + if self.async_mode: + return asyncio.run( + self._augment_async( + dataset=dataset, + split_name=split_name, + n_evolutions=n_evolutions, + update_split=update_split, + batch_size=batch_size, + ) + ) + original_split = dataset[split_name] new_samples = [] for sample in original_split: @@ -61,7 +89,43 @@ def augment( new_samples.extend( [{Dataset.label_feature: intent_data.id, Dataset.utterance_feature: ut} for ut in generated_utterances] ) + if update_split: generated_split = HFDataset.from_list(new_samples) dataset[split_name] = concatenate_datasets([original_split, generated_split]) + + return [Sample(**sample) for sample in new_samples] + + async def _augment_async( + self, + dataset: Dataset, + split_name: str = Split.TRAIN, + n_evolutions: int = 1, + update_split: bool = True, + batch_size: int = 4, + ) -> list[Sample]: + original_split = dataset[split_name] + new_samples = [] + + tasks = [] + labels = [] + for sample in original_split: + utterance = sample[Dataset.utterance_feature] + label = sample[Dataset.label_feature] + intent_data = next(intent for intent in dataset.intents if intent.id == label) + for _ in range(n_evolutions): + tasks.append(self._evolve_async(utterance, intent_data)) + labels.append(intent_data.id) + + for start_idx in range(0, len(tasks), batch_size): + batch_tasks = tasks[start_idx : start_idx + batch_size] + batch_labels = labels[start_idx : start_idx + batch_size] + batch_results = await asyncio.gather(*batch_tasks) + for result, intent_id in zip(batch_results, batch_labels, strict=False): + new_samples.append({Dataset.label_feature: intent_id, Dataset.utterance_feature: result}) + + if update_split: + generated_split = HFDataset.from_list(new_samples) + dataset[split_name] = concatenate_datasets([original_split, generated_split]) + return [Sample(**sample) for sample in new_samples] diff --git a/autointent/generation/utterances/generator.py b/autointent/generation/utterances/generator.py index dfc9ec869..d8d6a094e 100644 --- a/autointent/generation/utterances/generator.py +++ b/autointent/generation/utterances/generator.py @@ -14,11 +14,18 @@ class Generator: def __init__(self) -> None: """Initialize.""" load_dotenv() - self.client = openai.OpenAI(base_url=os.environ["OPENAI_BASE_URL"], api_key=os.environ["OPENAI_API_KEY"]) + self.client = openai.OpenAI( + base_url=os.environ["OPENAI_BASE_URL"], + api_key=os.environ["OPENAI_API_KEY"] + ) + self.async_client = openai.AsyncOpenAI( + base_url=os.environ["OPENAI_BASE_URL"], + api_key=os.environ["OPENAI_API_KEY"] + ) self.model_name = os.environ["OPENAI_MODEL_NAME"] def get_chat_completion(self, messages: list[Message]) -> str: - """Prompt LLM and return its answer.""" + """Prompt LLM and return its answer synchronously.""" response = self.client.chat.completions.create( messages=messages, # type: ignore[arg-type] model=self.model_name, @@ -28,3 +35,15 @@ def get_chat_completion(self, messages: list[Message]) -> str: temperature=0.7, ) return response.choices[0].message.content # type: ignore[return-value] + + async def get_chat_completion_async(self, messages: list[Message]) -> str: + """Prompt LLM and return its answer asynchronously.""" + response = await self.async_client.chat.completions.create( + messages=messages, # type: ignore[arg-type] + model=self.model_name, + max_tokens=150, + n=1, + stop=None, + temperature=0.7, + ) + return response.choices[0].message.content # type: ignore[return-value] diff --git a/tests/generation/utterances/test_basic_synthesizer.py 
diff --git a/tests/generation/utterances/test_basic_synthesizer.py b/tests/generation/utterances/test_basic_synthesizer.py
index 674b3535f..4047ce6fc 100644
--- a/tests/generation/utterances/test_basic_synthesizer.py
+++ b/tests/generation/utterances/test_basic_synthesizer.py
@@ -1,4 +1,4 @@
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from autointent.generation.utterances import SynthesizerChatTemplate, UtteranceGenerator
 
@@ -50,3 +50,49 @@ def test_on_dataset(dataset):
 
     assert n_before + len(new_samples) == n_after
     assert len(new_samples) == len(dataset.intents)
+
+
+def test_on_dataset_async(dataset):
+    mock_llm = AsyncMock()
+    mock_llm.get_chat_completion_async.return_value = "1. LLM answer"
+
+    split_name = "train_0"
+
+    template = SynthesizerChatTemplate(dataset, split=split_name)
+    augmenter = UtteranceGenerator(mock_llm, template, async_mode=True)
+
+    n_before = len(dataset[split_name])
+    new_samples = augmenter.augment(dataset, split_name=split_name, update_split=False)
+    n_after = len(dataset[split_name])
+
+    assert n_before == n_after
+    assert len(new_samples) == len(dataset.intents)
+    assert all(sample.utterance == "LLM answer" for sample in new_samples)
+
+    n_before = len(dataset[split_name])
+    new_samples = augmenter.augment(dataset, split_name=split_name, update_split=True)
+    n_after = len(dataset[split_name])
+
+    assert n_before + len(new_samples) == n_after
+    assert len(new_samples) == len(dataset.intents)
+
+
+def test_on_dataset_async_with_batch_size(dataset):
+    mock_llm = AsyncMock()
+    mock_llm.get_chat_completion_async.return_value = "1. LLM answer"
+
+    split_name = "train_0"
+
+    template = SynthesizerChatTemplate(dataset, split=split_name)
+    augmenter = UtteranceGenerator(mock_llm, template, async_mode=True)
+
+    batch_size = 2
+    new_samples = augmenter.augment(dataset, split_name=split_name, update_split=False, batch_size=batch_size)
+
+    assert len(new_samples) == len(dataset.intents)
+    assert all(sample.utterance == "LLM answer" for sample in new_samples)
+
+    batch_size = len(dataset.intents) + 5
+    new_samples = augmenter.augment(dataset, split_name=split_name, update_split=False, batch_size=batch_size)
+
+    assert len(new_samples) == len(dataset.intents)
+    assert all(sample.utterance == "LLM answer" for sample in new_samples)
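The async tests swap Mock for AsyncMock because the generator method is now awaited: calling an attribute of an AsyncMock produces a coroutine whose result is the configured return_value. A self-contained illustration of the mocking pattern these tests rely on:

import asyncio
from unittest.mock import AsyncMock

mock_llm = AsyncMock()
mock_llm.get_chat_completion_async.return_value = "1. LLM answer"

async def demo() -> str:
    # the call returns a coroutine; awaiting it yields the return_value
    return await mock_llm.get_chat_completion_async([])

assert asyncio.run(demo()) == "1. LLM answer"
mock_llm.get_chat_completion_async.assert_awaited_once_with([])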
diff --git a/tests/generation/utterances/test_evolver.py b/tests/generation/utterances/test_evolver.py
index c95c9defa..250a4e86a 100644
--- a/tests/generation/utterances/test_evolver.py
+++ b/tests/generation/utterances/test_evolver.py
@@ -1,4 +1,4 @@
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from autointent.generation.utterances import AbstractEvolution, UtteranceEvolver
 
@@ -32,3 +32,54 @@ def test_on_dataset(dataset):
 
     assert n_before + len(new_samples) == n_after
     assert len(new_samples) == n_before
+
+
+def test_on_dataset_evolver_async(dataset):
+    mock_llm = AsyncMock()
+    mock_llm.get_chat_completion_async.return_value = "LLM answer"
+
+    split_name = "train_0"
+
+    template = AbstractEvolution()
+    augmenter = UtteranceEvolver(mock_llm, [template], async_mode=True)
+
+    n_before = len(dataset[split_name])
+    new_samples = augmenter.augment(dataset, split_name=split_name, n_evolutions=1, update_split=False)
+    n_after = len(dataset[split_name])
+
+    assert n_before == n_after
+    assert len(new_samples) == n_before
+    assert all(sample.utterance == "LLM answer" for sample in new_samples)
+
+    n_before = len(dataset[split_name])
+    new_samples = augmenter.augment(dataset, split_name=split_name, n_evolutions=1, update_split=True)
+    n_after = len(dataset[split_name])
+
+    assert n_before + len(new_samples) == n_after
+    assert len(new_samples) == n_before
+
+
+def test_on_dataset_evolver_async_with_batch_size(dataset):
+    mock_llm = AsyncMock()
+    mock_llm.get_chat_completion_async.return_value = "LLM answer"
+
+    split_name = "train_0"
+
+    template = AbstractEvolution()
+    augmenter = UtteranceEvolver(mock_llm, [template], async_mode=True)
+
+    batch_size = 2
+    new_samples = augmenter.augment(
+        dataset, split_name=split_name, n_evolutions=1, update_split=False, batch_size=batch_size
+    )
+
+    assert len(new_samples) == len(dataset[split_name])
+    assert all(sample.utterance == "LLM answer" for sample in new_samples)
+
+    batch_size = len(dataset[split_name]) + 5
+    new_samples = augmenter.augment(
+        dataset, split_name=split_name, n_evolutions=1, update_split=False, batch_size=batch_size
+    )
+
+    assert len(new_samples) == len(dataset[split_name])
+    assert all(sample.utterance == "LLM answer" for sample in new_samples)
diff --git a/tests/generation/utterances/test_generator.py b/tests/generation/utterances/test_generator.py
index 667f3dc9c..ad1930b9b 100644
--- a/tests/generation/utterances/test_generator.py
+++ b/tests/generation/utterances/test_generator.py
@@ -1,4 +1,4 @@
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 
@@ -36,3 +36,30 @@ def test_get_chat_completion(mock_openai_client):
 
     assert response == "Test response"
     mock_openai_client.return_value.chat.completions.create.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_get_chat_completion_async():
+    test_messages = [Message(role="user", content="Hello, how are you?")]
+
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock(message=MagicMock(content="I'm fine, thank you!"))]
+
+    with patch("openai.AsyncOpenAI") as mock_async_openai:
+        mock_instance = mock_async_openai.return_value
+        mock_instance.chat.completions.create = AsyncMock(return_value=mock_response)
+
+        generator = Generator()
+
+        result = await generator.get_chat_completion_async(test_messages)
+
+        assert result == "I'm fine, thank you!"
+
+        mock_instance.chat.completions.create.assert_awaited_once_with(
+            messages=test_messages,
+            model=generator.model_name,
+            max_tokens=150,
+            n=1,
+            stop=None,
+            temperature=0.7,
+        )
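Two small caveats worth noting. test_get_chat_completion_async relies on the pytest-asyncio plugin for @pytest.mark.asyncio to take effect. And because augment enters the async path through asyncio.run, it cannot be called from code that is already inside a running event loop (a Jupyter cell, for example), as this sketch demonstrates:

import asyncio

async def inner() -> str:
    return "done"

def augment_like() -> str:
    # mirrors the async_mode branch of augment(): asyncio.run creates and
    # manages a fresh event loop for the duration of the call
    return asyncio.run(inner())

print(augment_like())  # fine from plain synchronous code

async def caller() -> None:
    try:
        augment_like()
    except RuntimeError as err:
        # asyncio.run() cannot be called from a running event loop
        # (Python also warns that the inner() coroutine was never awaited)
        print(err)

asyncio.run(caller())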