Skip to content

Commit 8273d0f

Browse files
committed
feat: added sequential
1 parent d2ac6e1 commit 8273d0f

File tree

3 files changed

+42
-5
lines changed

3 files changed

+42
-5
lines changed

autointent/generation/utterances/evolution/cli.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ def _parse_args() -> Namespace:
4545
parser.add_argument("--seed", type=int, default=0)
4646
parser.add_argument("--batch-size", type=int, default=4)
4747
parser.add_argument("--search-space", type=str, default=None)
48+
parser.add_argument(
49+
"--sequential",
50+
action="store_true",
51+
help=(
52+
"Use sequential evolution. When this option is enabled, solutions "
53+
"will evolve one after another, instead of using a parallel approach."
54+
),
55+
)
4856

4957
return parser.parse_args()
5058

@@ -64,7 +72,11 @@ def main() -> None:
6472
n_before = len(dataset[args.split])
6573

6674
new_samples = utterance_evolver.augment(
67-
dataset, split_name=args.split, n_evolutions=args.n_evolutions, batch_size=args.batch_size
75+
dataset,
76+
split_name=args.split,
77+
n_evolutions=args.n_evolutions,
78+
batch_size=args.batch_size,
79+
sequential=args.sequential,
6880
)
6981
n_after = len(dataset[args.split])
7082

autointent/generation/utterances/evolution/evolver.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,21 @@ async def _evolve_async(self, utterance: str, intent_data: Intent) -> str:
5151
chat = maker(utterance, intent_data)
5252
return await self.generator.get_chat_completion_async(chat)
5353

54-
def __call__(self, utterance: str, intent_data: Intent, n_evolutions: int = 1) -> list[str]:
54+
def __call__(
    self, utterance: str, intent_data: Intent, n_evolutions: int = 1, sequential: bool = False
) -> list[str]:
    """Apply evolutions multiple times (synchronously).

    When ``sequential`` is True, each evolution step starts from the
    previously generated utterance (chained evolution); otherwise every
    step evolves the original ``utterance`` independently.

    :param utterance: Seed utterance to evolve.
    :param intent_data: Intent the utterance belongs to.
    :param n_evolutions: How many evolved variants to produce.
    :param sequential: Chain evolutions instead of evolving the seed each time.
    :return: List of ``n_evolutions`` generated utterances.
    """
    results: list[str] = []
    seed = utterance

    for _ in range(n_evolutions):
        evolved = self._evolve(seed, intent_data)
        results.append(evolved)
        # Feed the latest output back in only when chaining is requested.
        if sequential:
            seed = evolved

    return results
5769

5870
def augment(
5971
self,
@@ -62,13 +74,18 @@ def augment(
6274
n_evolutions: int = 1,
6375
update_split: bool = True,
6476
batch_size: int = 4,
77+
sequential: bool = False,
6578
) -> HFDataset:
6679
"""
6780
Augment some split of dataset.
6881
6982
Note that for now it supports only single-label datasets.
7083
"""
7184
if self.async_mode:
85+
if sequential:
86+
error = "Sequential and async modes are not compatible"
87+
raise ValueError(error)
88+
7289
return asyncio.run(
7390
self._augment_async(
7491
dataset=dataset,
@@ -85,7 +102,9 @@ def augment(
85102
utterance = sample[Dataset.utterance_feature]
86103
label = sample[Dataset.label_feature]
87104
intent_data = next(intent for intent in dataset.intents if intent.id == label)
88-
generated_utterances = self(utterance=utterance, intent_data=intent_data, n_evolutions=n_evolutions)
105+
generated_utterances = self(
106+
utterance=utterance, intent_data=intent_data, n_evolutions=n_evolutions, sequential=sequential
107+
)
89108
new_samples.extend(
90109
[{Dataset.label_feature: intent_data.id, Dataset.utterance_feature: ut} for ut in generated_utterances]
91110
)

autointent/generation/utterances/evolution/incremental_evolver.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def augment(
6767
n_evolutions: int = 1,
6868
update_split: bool = True,
6969
batch_size: int = 4,
70+
sequential: bool = False,
7071
) -> HFDataset:
7172
"""
7273
Augment some split of dataset.
@@ -79,7 +80,12 @@ def augment(
7980

8081
for _ in range(n_evolutions):
8182
new_samples_dataset = super().augment(
82-
dataset, split_name=split_name, n_evolutions=1, update_split=False, batch_size=batch_size
83+
dataset,
84+
split_name=split_name,
85+
n_evolutions=1,
86+
update_split=False,
87+
batch_size=batch_size,
88+
sequential=sequential,
8389
)
8490
merge_dataset[split_name] = concatenate_datasets([merge_dataset[split_name], new_samples_dataset])
8591
generated_samples.append(new_samples_dataset)

0 commit comments

Comments (0)