Commit d1e4716

feat: added optimizer n evolutions (#126)

* feat: added optimizer n evolutions
* refactor: effective way to generate evolutions
* refactor: effective way to generate evolutions
* feat: added search space
* refactor
* refactor
* feat: added tests and fix mypy
* feat: IncrementalUtteranceEvolver
* feat: deleted embedder config
* fix: mypy
* fix: import
* feat: updated cli and search space
* feat: updated tests
* feat: updated __init__
* fix: mypy

1 parent 71d714f commit d1e4716

File tree

6 files changed: +232, -36 lines changed

autointent/generation/utterances/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -6,6 +6,7 @@
     FormalEvolution,
     FunnyEvolution,
     GoofyEvolution,
+    IncrementalUtteranceEvolver,
     InformalEvolution,
     ReasoningEvolution,
     UtteranceEvolver,
@@ -20,6 +21,7 @@
     "FunnyEvolution",
     "Generator",
     "GoofyEvolution",
+    "IncrementalUtteranceEvolver",
     "InformalEvolution",
     "ReasoningEvolution",
     "SynthesizerChatTemplate",

autointent/generation/utterances/evolution/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -9,6 +9,7 @@
     ReasoningEvolution,
 )
 from .evolver import UtteranceEvolver
+from .incremental_evolver import IncrementalUtteranceEvolver
 
 __all__ = [
     "AbstractEvolution",
@@ -17,6 +18,7 @@
     "FormalEvolution",
     "FunnyEvolution",
     "GoofyEvolution",
+    "IncrementalUtteranceEvolver",
     "InformalEvolution",
     "ReasoningEvolution",
     "UtteranceEvolver",

autointent/generation/utterances/evolution/cli.py

Lines changed: 35 additions & 24 deletions

@@ -1,16 +1,15 @@
 """CLI for evolutionary augmenter."""
 
 import logging
-from argparse import ArgumentParser
+from argparse import ArgumentParser, Namespace
 
 from autointent import load_dataset
-from autointent.generation.utterances.evolution.evolver import UtteranceEvolver
+from autointent.generation.utterances.evolution import IncrementalUtteranceEvolver, UtteranceEvolver
 from autointent.generation.utterances.generator import Generator
 
 from .chat_templates import (
     AbstractEvolution,
     ConcreteEvolution,
-    EvolutionChatTemplate,
     FormalEvolution,
     FunnyEvolution,
     GoofyEvolution,
@@ -22,8 +21,7 @@
 logger = logging.getLogger(__name__)
 
 
-def main() -> None:
-    """CLI endpoint."""
+def _parse_args() -> Namespace:
     parser = ArgumentParser()
     parser.add_argument(
         "--input-path",
@@ -46,6 +44,7 @@ def main() -> None:
     )
     parser.add_argument("--private", action="store_true", help="Publish privately if --output-repo option is used")
     parser.add_argument("--n-evolutions", type=int, default=1, help="Number of utterances to generate for each intent")
+    parser.add_argument("--decide-for-me", action="store_true")
     parser.add_argument("--reasoning", action="store_true", help="Whether to use `Reasoning` evolution")
     parser.add_argument("--concretizing", action="store_true", help="Whether to use `Concretizing` evolution")
     parser.add_argument("--abstract", action="store_true", help="Whether to use `Abstract` evolution")
@@ -55,34 +54,46 @@ def main() -> None:
     parser.add_argument("--informal", action="store_true", help="Whether to use `Informal` evolution")
     parser.add_argument("--async-mode", action="store_true", help="Enable asynchronous generation")
     parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--batch-size", type=int, default=4)
+    parser.add_argument("--search-space", type=str, default=None)
+
+    return parser.parse_args()
+
 
-    args = parser.parse_args()
-
-    evolutions: list[EvolutionChatTemplate] = []
-    if args.reasoning:
-        evolutions.append(ReasoningEvolution())
-    if args.concretizing:
-        evolutions.append(ConcreteEvolution())
-    if args.abstract:
-        evolutions.append(AbstractEvolution())
-    if args.formal:
-        evolutions.append(FormalEvolution())
-    if args.funny:
-        evolutions.append(FunnyEvolution())
-    if args.goofy:
-        evolutions.append(GoofyEvolution())
-    if args.informal:
-        evolutions.append(InformalEvolution())
+def main() -> None:
+    """CLI endpoint."""
+    mapping = {
+        "reasoning": ReasoningEvolution,
+        "concretizing": ConcreteEvolution,
+        "abstract": AbstractEvolution,
+        "formal": FormalEvolution,
+        "funny": FunnyEvolution,
+        "goofy": GoofyEvolution,
+        "informal": InformalEvolution,
+    }
+    args = _parse_args()
+    evolutions = []
+
+    for arg_name, evolution_cls in mapping.items():
+        if getattr(args, arg_name):
+            evolutions.append(evolution_cls())  # type: ignore[abstract]
 
     if not evolutions:
         logger.warning("No evolutions selected. Exiting.")
        return
 
+    utterance_evolver: UtteranceEvolver
+    if args.decide_for_me:
+        utterance_evolver = IncrementalUtteranceEvolver(Generator(), evolutions, args.seed, args.async_mode)
+    else:
+        utterance_evolver = UtteranceEvolver(Generator(), evolutions, args.seed, args.async_mode)
     dataset = load_dataset(args.input_path)
+
     n_before = len(dataset[args.split])
 
-    generator = UtteranceEvolver(Generator(), evolutions, args.seed, async_mode=args.async_mode)
-    new_samples = generator.augment(dataset, split_name=args.split, n_evolutions=args.n_evolutions)
+    new_samples = utterance_evolver.augment(
+        dataset, split_name=args.split, n_evolutions=args.n_evolutions, batch_size=args.batch_size
    )
     n_after = len(dataset[args.split])
 
     logger.info("# samples before %s", n_before)

autointent/generation/utterances/evolution/evolver.py

Lines changed: 7 additions & 7 deletions

@@ -15,7 +15,7 @@
 from autointent.custom_types import Split
 from autointent.generation.utterances.generator import Generator
 from autointent.generation.utterances.schemas import Message
-from autointent.schemas import Intent, Sample
+from autointent.schemas import Intent
 
 
 class UtteranceEvolver:
@@ -62,7 +62,7 @@ def augment(
         n_evolutions: int = 1,
         update_split: bool = True,
         batch_size: int = 4,
-    ) -> list[Sample]:
+    ) -> HFDataset:
         """
         Augment some split of dataset.
 
@@ -90,11 +90,11 @@ def augment(
             [{Dataset.label_feature: intent_data.id, Dataset.utterance_feature: ut} for ut in generated_utterances]
         )
 
+        generated_split = HFDataset.from_list(new_samples)
         if update_split:
-            generated_split = HFDataset.from_list(new_samples)
             dataset[split_name] = concatenate_datasets([original_split, generated_split])
 
-        return [Sample(**sample) for sample in new_samples]
+        return generated_split
 
     async def _augment_async(
         self,
@@ -103,7 +103,7 @@ async def _augment_async(
         n_evolutions: int = 1,
         update_split: bool = True,
         batch_size: int = 4,
-    ) -> list[Sample]:
+    ) -> HFDataset:
         original_split = dataset[split_name]
         new_samples = []
 
@@ -124,8 +124,8 @@ async def _augment_async(
         for result, intent_id in zip(batch_results, batch_labels, strict=False):
             new_samples.append({Dataset.label_feature: intent_id, Dataset.utterance_feature: result})
 
+        generated_split = HFDataset.from_list(new_samples)
         if update_split:
-            generated_split = HFDataset.from_list(new_samples)
             dataset[split_name] = concatenate_datasets([original_split, generated_split])
 
-        return [Sample(**sample) for sample in new_samples]
+        return generated_split
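
Both augment() and _augment_async() now return the generated rows as a Hugging Face Dataset instead of a list of Sample objects, and generated_split is built unconditionally so it can be returned even when update_split is False. A minimal sketch of consuming the new return value, assuming an already constructed UtteranceEvolver (evolver) and a loaded autointent Dataset (dataset):

# `evolver` and `dataset` are assumed to exist; building them requires a Generator
# and a dataset, as in the CLI above.
new_split = evolver.augment(dataset, n_evolutions=1, batch_size=4)

print(new_split.column_names)  # the label and utterance columns used by autointent.Dataset
for row in new_split:          # each row is a plain dict, no longer a Sample instance
    print(row)
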
autointent/generation/utterances/evolution/incremental_evolver.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+"""
+Evolutionary strategy to augmenting utterances.
+
+Deeply inspired by DeepEval evolutions.
+"""
+
+import copy
+from collections.abc import Callable, Sequence
+from pathlib import Path
+from typing import Any
+
+from datasets import Dataset as HFDataset
+from datasets import concatenate_datasets
+
+from autointent import Dataset, Pipeline
+from autointent.custom_types import Split
+from autointent.generation.utterances.evolution.evolver import UtteranceEvolver
+from autointent.generation.utterances.generator import Generator
+from autointent.generation.utterances.schemas import Message
+from autointent.schemas import Intent
+
+SEARCH_SPACE = [
+    {
+        "node_type": "scoring",
+        "target_metric": "scoring_roc_auc",
+        "metrics": ["scoring_accuracy"],
+        "search_space": [
+            {
+                "module_name": "linear",
+                "embedder_config": ["sentence-transformers/all-MiniLM-L6-v2"],
+            }
+        ],
+    },
+    {
+        "node_type": "decision",
+        "target_metric": "decision_accuracy",
+        "search_space": [
+            {"module_name": "argmax"},
+        ],
+    },
+]
+
+
+class IncrementalUtteranceEvolver(UtteranceEvolver):
+    """Incremental evolutionary strategy to augmenting utterances."""
+
+    def __init__(
+        self,
+        generator: Generator,
+        prompt_makers: Sequence[Callable[[str, Intent], list[Message]]],
+        seed: int = 0,
+        async_mode: bool = False,
+        search_space: str | None = None,
+    ) -> None:
+        """Initialize."""
+        super().__init__(generator, prompt_makers, seed, async_mode)
+        self.search_space = self._choose_search_space(search_space)
+
+    def _choose_search_space(self, search_space: str | None) -> list[dict[str, Any]] | Path | str:
+        if search_space is None:
+            return SEARCH_SPACE
+        return search_space
+
+    def augment(
+        self,
+        dataset: Dataset,
+        split_name: str = Split.TRAIN,
+        n_evolutions: int = 1,
+        update_split: bool = True,
+        batch_size: int = 4,
+    ) -> HFDataset:
+        """
+        Augment some split of dataset.
+
+        Note that for now it supports only single-label datasets.
+        """
+        best_result = 0
+        merge_dataset = copy.deepcopy(dataset)
+        generated_samples = []
+
+        for _ in range(n_evolutions):
+            new_samples_dataset = super().augment(
+                dataset, split_name=split_name, n_evolutions=1, update_split=False, batch_size=batch_size
+            )
+            merge_dataset[split_name] = concatenate_datasets([merge_dataset[split_name], new_samples_dataset])
+            generated_samples.append(new_samples_dataset)
+
+            pipeline_optimizer = Pipeline.from_search_space(self.search_space)
+            ctx = pipeline_optimizer.fit(merge_dataset)
+            results = ctx.optimization_info.dump_evaluation_results()
+            decision_metric = results["metrics"]["decision"][0]
+
+            if decision_metric > best_result:
+                best_result = decision_metric
+            else:
+                break
+
+        if update_split:
+            dataset[split_name] = merge_dataset[split_name]
+
+        return concatenate_datasets(generated_samples)
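
IncrementalUtteranceEvolver generates one evolution round at a time, fits a small scoring + decision pipeline (built from SEARCH_SPACE or a user-supplied search space) on the merged split after each round, and stops as soon as the decision metric no longer improves. A usage sketch under assumptions: the dataset name and the chosen evolution templates are illustrative, and Generator() is assumed to be configured the same way the CLI constructs it.

from autointent import load_dataset
from autointent.generation.utterances import (
    FormalEvolution,
    Generator,
    IncrementalUtteranceEvolver,
    ReasoningEvolution,
)

dataset = load_dataset("my_intents")  # assumed: any identifier load_dataset accepts

evolver = IncrementalUtteranceEvolver(
    Generator(),
    [ReasoningEvolution(), FormalEvolution()],  # evolution chat templates act as prompt makers
    seed=0,
    async_mode=False,
    # search_space=None falls back to the built-in SEARCH_SPACE above;
    # a string (e.g. a path to a search-space config) can be passed instead.
)

# Up to five rounds; the loop breaks early once the fitted pipeline's
# decision metric stops improving on the merged train split.
generated = evolver.augment(dataset, n_evolutions=5, batch_size=8)
print(len(generated), "generated utterances kept")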
