Merged
14 changes: 13 additions & 1 deletion autointent/generation/utterances/evolution/cli.py
@@ -45,6 +45,14 @@ def _parse_args() -> Namespace:
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, default=4)
parser.add_argument("--search-space", type=str, default=None)
parser.add_argument(
"--sequential",
action="store_true",
help=(
"Use sequential evolution. When this option is enabled, solutions "
"will evolve one after another, instead of using a parallel approach."
),
)

return parser.parse_args()

@@ -64,7 +72,11 @@ def main() -> None:
n_before = len(dataset[args.split])

new_samples = utterance_evolver.augment(
dataset, split_name=args.split, n_evolutions=args.n_evolutions, batch_size=args.batch_size
dataset,
split_name=args.split,
n_evolutions=args.n_evolutions,
batch_size=args.batch_size,
sequential=args.sequential,
)
n_after = len(dataset[args.split])

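For reference, a minimal standalone sketch (an illustration, not the project's actual parser) of how a store_true flag such as the new --sequential option behaves: it defaults to False and becomes True only when the flag is passed on the command line.

# Standalone sketch of a store_true flag like --sequential; hypothetical parser, not cli.py.
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--n-evolutions", type=int, default=1)
parser.add_argument("--sequential", action="store_true")

print(parser.parse_args([]))                 # Namespace(n_evolutions=1, sequential=False)
print(parser.parse_args(["--sequential"]))   # Namespace(n_evolutions=1, sequential=True)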
25 changes: 22 additions & 3 deletions autointent/generation/utterances/evolution/evolver.py
@@ -51,9 +51,21 @@ async def _evolve_async(self, utterance: str, intent_data: Intent) -> str:
chat = maker(utterance, intent_data)
return await self.generator.get_chat_completion_async(chat)

def __call__(self, utterance: str, intent_data: Intent, n_evolutions: int = 1) -> list[str]:
def __call__(
self, utterance: str, intent_data: Intent, n_evolutions: int = 1, sequential: bool = False
) -> list[str]:
"""Apply evolutions multiple times (synchronously)."""
return [self._evolve(utterance, intent_data) for _ in range(n_evolutions)]
current_utterance = utterance
generated_utterances = []

for _ in range(n_evolutions):
gen_utt = self._evolve(current_utterance, intent_data)
generated_utterances.append(gen_utt)

if sequential:
current_utterance = gen_utt

return generated_utterances

def augment(
self,
@@ -62,13 +74,18 @@ def augment(
n_evolutions: int = 1,
update_split: bool = True,
batch_size: int = 4,
sequential: bool = False,
) -> HFDataset:
"""
Augment some split of dataset.

Note that for now it supports only single-label datasets.
"""
if self.async_mode:
if sequential:
error = "Sequential and async modes are not compatible"
raise ValueError(error)
Comment on lines 84 to +87 (Collaborator): more like "not supported yet".

return asyncio.run(
self._augment_async(
dataset=dataset,
@@ -85,7 +102,9 @@ def augment(
utterance = sample[Dataset.utterance_feature]
label = sample[Dataset.label_feature]
intent_data = next(intent for intent in dataset.intents if intent.id == label)
generated_utterances = self(utterance=utterance, intent_data=intent_data, n_evolutions=n_evolutions)
generated_utterances = self(
utterance=utterance, intent_data=intent_data, n_evolutions=n_evolutions, sequential=sequential
)
new_samples.extend(
[{Dataset.label_feature: intent_data.id, Dataset.utterance_feature: ut} for ut in generated_utterances]
)
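A toy sketch (hypothetical names, not part of the diff) of the behaviour the sequential flag adds to __call__ above: with sequential=True each generated utterance feeds the next evolution, while the default keeps evolving the original utterance independently.

# Toy illustration of sequential chaining; `evolve` is a hypothetical
# stand-in for self._evolve(utterance, intent_data).
def evolve(utterance: str) -> str:
    return utterance + " (evolved)"

def run(utterance: str, n_evolutions: int, sequential: bool = False) -> list[str]:
    current = utterance
    generated = []
    for _ in range(n_evolutions):
        new_utt = evolve(current)
        generated.append(new_utt)
        if sequential:
            current = new_utt  # next evolution builds on this result
    return generated

print(run("hi", 2))                    # ['hi (evolved)', 'hi (evolved)']
print(run("hi", 2, sequential=True))   # ['hi (evolved)', 'hi (evolved) (evolved)']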
@@ -67,6 +67,7 @@ def augment(
n_evolutions: int = 1,
update_split: bool = True,
batch_size: int = 4,
sequential: bool = False,
) -> HFDataset:
"""
Augment some split of dataset.
@@ -79,7 +80,12 @@ def augment(

for _ in range(n_evolutions):
new_samples_dataset = super().augment(
dataset, split_name=split_name, n_evolutions=1, update_split=False, batch_size=batch_size
dataset,
split_name=split_name,
n_evolutions=1,
update_split=False,
batch_size=batch_size,
sequential=sequential,
)
merge_dataset[split_name] = concatenate_datasets([merge_dataset[split_name], new_samples_dataset])
generated_samples.append(new_samples_dataset)
38 changes: 38 additions & 0 deletions tests/generation/utterances/test_evolver.py
@@ -1,5 +1,7 @@
from unittest.mock import AsyncMock, Mock

import pytest

from autointent.generation.utterances import AbstractEvolution, IncrementalUtteranceEvolver, UtteranceEvolver


@@ -28,6 +30,14 @@ def test_on_dataset_incremental(dataset):
assert len(new_samples) == n_before
assert set(new_samples.column_names) == set(dataset[split_name].column_names)

n_before = len(dataset[split_name])
new_samples = augmenter.augment(dataset, split_name=split_name, n_evolutions=1, update_split=True, sequential=True)
n_after = len(dataset[split_name])

assert n_before + len(new_samples) == n_after
assert len(new_samples) == n_before
assert set(new_samples.column_names) == set(dataset[split_name].column_names)


def test_on_dataset_increment_evolver_async(dataset):
mock_llm = AsyncMock()
@@ -54,6 +64,11 @@ def test_on_dataset_increment_evolver_async(dataset):
assert len(new_samples) == n_before
assert set(new_samples.column_names) == set(dataset[split_name].column_names)

with pytest.raises(ValueError, match="Sequential and async modes are not compatible"):
new_samples = augmenter.augment(
dataset, split_name=split_name, n_evolutions=1, update_split=True, sequential=True
)


def test_on_dataset_increment_evolver_async_with_batch_size(dataset):
mock_llm = AsyncMock()
@@ -80,6 +95,11 @@ def test_on_dataset_increment_evolver_async_with_batch_size(dataset):
assert len(new_samples) == len(dataset[split_name])
assert set(new_samples.column_names) == set(dataset[split_name].column_names)

with pytest.raises(ValueError, match="Sequential and async modes are not compatible"):
new_samples = augmenter.augment(
dataset, split_name=split_name, n_evolutions=1, update_split=True, batch_size=batch_size, sequential=True
)


def test_default_chat_template(dataset):
template = AbstractEvolution()
@@ -112,6 +132,14 @@ def test_on_dataset(dataset):
assert len(new_samples) == n_before
assert set(new_samples.column_names) == set(dataset[split_name].column_names)

n_before = len(dataset[split_name])
new_samples = augmenter.augment(dataset, split_name=split_name, n_evolutions=1, update_split=True, sequential=True)
n_after = len(dataset[split_name])

assert n_before + len(new_samples) == n_after
assert len(new_samples) == n_before
assert set(new_samples.column_names) == set(dataset[split_name].column_names)


def test_on_dataset_evolver_async(dataset):
mock_llm = AsyncMock()
@@ -138,6 +166,11 @@ def test_on_dataset_evolver_async(dataset):
assert len(new_samples) == n_before
assert set(new_samples.column_names) == set(dataset[split_name].column_names)

with pytest.raises(ValueError, match="Sequential and async modes are not compatible"):
new_samples = augmenter.augment(
dataset, split_name=split_name, n_evolutions=1, update_split=True, sequential=True
)


def test_on_dataset_evolver_async_with_batch_size(dataset):
mock_llm = AsyncMock()
@@ -163,3 +196,8 @@ def test_on_dataset_evolver_async_with_batch_size(dataset):

assert len(new_samples) == len(dataset[split_name])
assert set(new_samples.column_names) == set(dataset[split_name].column_names)

with pytest.raises(ValueError, match="Sequential and async modes are not compatible"):
new_samples = augmenter.augment(
dataset, split_name=split_name, n_evolutions=1, update_split=True, batch_size=batch_size, sequential=True
)
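
For context, a self-contained sketch (hypothetical helper, not from the test suite) of the pytest.raises pattern used in the new tests: match is a regular expression searched against the string of the raised exception.

# Standalone illustration of pytest.raises(..., match=...);
# `guarded_augment` is a hypothetical stand-in for the real augment guard.
import pytest

def guarded_augment(sequential: bool, async_mode: bool) -> None:
    if async_mode and sequential:
        raise ValueError("Sequential and async modes are not compatible")

def test_sequential_async_guard() -> None:
    with pytest.raises(ValueError, match="Sequential and async modes are not compatible"):
        guarded_augment(sequential=True, async_mode=True)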