Skip to content

Commit 865f0ed

Browse files
authored
Feat/sequential evolution (#149)
* feat: added sequential * feat: added tests
1 parent 54b8fad commit 865f0ed

File tree

4 files changed

+80
-5
lines changed

4 files changed

+80
-5
lines changed

autointent/generation/utterances/evolution/cli.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ def _parse_args() -> Namespace:
4545
parser.add_argument("--seed", type=int, default=0)
4646
parser.add_argument("--batch-size", type=int, default=4)
4747
parser.add_argument("--search-space", type=str, default=None)
48+
parser.add_argument(
49+
"--sequential",
50+
action="store_true",
51+
help=(
52+
"Use sequential evolution. When this option is enabled, solutions "
53+
"will evolve one after another, instead of using a parallel approach."
54+
),
55+
)
4856

4957
return parser.parse_args()
5058

@@ -64,7 +72,11 @@ def main() -> None:
6472
n_before = len(dataset[args.split])
6573

6674
new_samples = utterance_evolver.augment(
67-
dataset, split_name=args.split, n_evolutions=args.n_evolutions, batch_size=args.batch_size
75+
dataset,
76+
split_name=args.split,
77+
n_evolutions=args.n_evolutions,
78+
batch_size=args.batch_size,
79+
sequential=args.sequential,
6880
)
6981
n_after = len(dataset[args.split])
7082

autointent/generation/utterances/evolution/evolver.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,21 @@ async def _evolve_async(self, utterance: str, intent_data: Intent) -> str:
5151
chat = maker(utterance, intent_data)
5252
return await self.generator.get_chat_completion_async(chat)
5353

54-
def __call__(self, utterance: str, intent_data: Intent, n_evolutions: int = 1) -> list[str]:
54+
def __call__(
55+
self, utterance: str, intent_data: Intent, n_evolutions: int = 1, sequential: bool = False
56+
) -> list[str]:
5557
"""Apply evolutions multiple times (synchronously)."""
56-
return [self._evolve(utterance, intent_data) for _ in range(n_evolutions)]
58+
current_utterance = utterance
59+
generated_utterances = []
60+
61+
for _ in range(n_evolutions):
62+
gen_utt = self._evolve(current_utterance, intent_data)
63+
generated_utterances.append(gen_utt)
64+
65+
if sequential:
66+
current_utterance = gen_utt
67+
68+
return generated_utterances
5769

5870
def augment(
5971
self,
@@ -62,13 +74,18 @@ def augment(
6274
n_evolutions: int = 1,
6375
update_split: bool = True,
6476
batch_size: int = 4,
77+
sequential: bool = False,
6578
) -> HFDataset:
6679
"""
6780
Augment some split of dataset.
6881
6982
Note that for now it supports only single-label datasets.
7083
"""
7184
if self.async_mode:
85+
if sequential:
86+
error = "Sequential and async modes are not compatible"
87+
raise ValueError(error)
88+
7289
return asyncio.run(
7390
self._augment_async(
7491
dataset=dataset,
@@ -85,7 +102,9 @@ def augment(
85102
utterance = sample[Dataset.utterance_feature]
86103
label = sample[Dataset.label_feature]
87104
intent_data = next(intent for intent in dataset.intents if intent.id == label)
88-
generated_utterances = self(utterance=utterance, intent_data=intent_data, n_evolutions=n_evolutions)
105+
generated_utterances = self(
106+
utterance=utterance, intent_data=intent_data, n_evolutions=n_evolutions, sequential=sequential
107+
)
89108
new_samples.extend(
90109
[{Dataset.label_feature: intent_data.id, Dataset.utterance_feature: ut} for ut in generated_utterances]
91110
)

autointent/generation/utterances/evolution/incremental_evolver.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def augment(
6767
n_evolutions: int = 1,
6868
update_split: bool = True,
6969
batch_size: int = 4,
70+
sequential: bool = False,
7071
) -> HFDataset:
7172
"""
7273
Augment some split of dataset.
@@ -79,7 +80,12 @@ def augment(
7980

8081
for _ in range(n_evolutions):
8182
new_samples_dataset = super().augment(
82-
dataset, split_name=split_name, n_evolutions=1, update_split=False, batch_size=batch_size
83+
dataset,
84+
split_name=split_name,
85+
n_evolutions=1,
86+
update_split=False,
87+
batch_size=batch_size,
88+
sequential=sequential,
8389
)
8490
merge_dataset[split_name] = concatenate_datasets([merge_dataset[split_name], new_samples_dataset])
8591
generated_samples.append(new_samples_dataset)

tests/generation/utterances/test_evolver.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from unittest.mock import AsyncMock, Mock
22

3+
import pytest
4+
35
from autointent.generation.utterances import AbstractEvolution, IncrementalUtteranceEvolver, UtteranceEvolver
46

57

@@ -28,6 +30,14 @@ def test_on_dataset_incremental(dataset):
2830
assert len(new_samples) == n_before
2931
assert set(new_samples.column_names) == set(dataset[split_name].column_names)
3032

33+
n_before = len(dataset[split_name])
34+
new_samples = augmenter.augment(dataset, split_name=split_name, n_evolutions=1, update_split=True, sequential=True)
35+
n_after = len(dataset[split_name])
36+
37+
assert n_before + len(new_samples) == n_after
38+
assert len(new_samples) == n_before
39+
assert set(new_samples.column_names) == set(dataset[split_name].column_names)
40+
3141

3242
def test_on_dataset_increment_evolver_async(dataset):
3343
mock_llm = AsyncMock()
@@ -54,6 +64,11 @@ def test_on_dataset_increment_evolver_async(dataset):
5464
assert len(new_samples) == n_before
5565
assert set(new_samples.column_names) == set(dataset[split_name].column_names)
5666

67+
with pytest.raises(ValueError, match="Sequential and async modes are not compatible"):
68+
new_samples = augmenter.augment(
69+
dataset, split_name=split_name, n_evolutions=1, update_split=True, sequential=True
70+
)
71+
5772

5873
def test_on_dataset_increment_evolver_async_with_batch_size(dataset):
5974
mock_llm = AsyncMock()
@@ -80,6 +95,11 @@ def test_on_dataset_increment_evolver_async_with_batch_size(dataset):
8095
assert len(new_samples) == len(dataset[split_name])
8196
assert set(new_samples.column_names) == set(dataset[split_name].column_names)
8297

98+
with pytest.raises(ValueError, match="Sequential and async modes are not compatible"):
99+
new_samples = augmenter.augment(
100+
dataset, split_name=split_name, n_evolutions=1, update_split=True, batch_size=batch_size, sequential=True
101+
)
102+
83103

84104
def test_default_chat_template(dataset):
85105
template = AbstractEvolution()
@@ -112,6 +132,14 @@ def test_on_dataset(dataset):
112132
assert len(new_samples) == n_before
113133
assert set(new_samples.column_names) == set(dataset[split_name].column_names)
114134

135+
n_before = len(dataset[split_name])
136+
new_samples = augmenter.augment(dataset, split_name=split_name, n_evolutions=1, update_split=True, sequential=True)
137+
n_after = len(dataset[split_name])
138+
139+
assert n_before + len(new_samples) == n_after
140+
assert len(new_samples) == n_before
141+
assert set(new_samples.column_names) == set(dataset[split_name].column_names)
142+
115143

116144
def test_on_dataset_evolver_async(dataset):
117145
mock_llm = AsyncMock()
@@ -138,6 +166,11 @@ def test_on_dataset_evolver_async(dataset):
138166
assert len(new_samples) == n_before
139167
assert set(new_samples.column_names) == set(dataset[split_name].column_names)
140168

169+
with pytest.raises(ValueError, match="Sequential and async modes are not compatible"):
170+
new_samples = augmenter.augment(
171+
dataset, split_name=split_name, n_evolutions=1, update_split=True, sequential=True
172+
)
173+
141174

142175
def test_on_dataset_evolver_async_with_batch_size(dataset):
143176
mock_llm = AsyncMock()
@@ -163,3 +196,8 @@ def test_on_dataset_evolver_async_with_batch_size(dataset):
163196

164197
assert len(new_samples) == len(dataset[split_name])
165198
assert set(new_samples.column_names) == set(dataset[split_name].column_names)
199+
200+
with pytest.raises(ValueError, match="Sequential and async modes are not compatible"):
201+
new_samples = augmenter.augment(
202+
dataset, split_name=split_name, n_evolutions=1, update_split=True, batch_size=batch_size, sequential=True
203+
)

0 commit comments

Comments
 (0)