Skip to content

Commit ed36291

Browse files
committed
refactor
1 parent 8fe33ab commit ed36291

File tree

1 file changed

+28
-47
lines changed
  • autointent/generation/utterances/evolution

1 file changed

+28
-47
lines changed

autointent/generation/utterances/evolution/cli.py

Lines changed: 28 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import logging
44
from argparse import ArgumentParser, Namespace
55
from pathlib import Path
6-
from typing import Any
6+
from typing import Any, Optional
77

88
from datasets import concatenate_datasets
9+
from datasets import Dataset as HFDataset
910

1011
from autointent import Dataset, Pipeline, load_dataset
1112
from autointent.configs import EmbedderConfig
@@ -63,24 +64,22 @@ def _choose_search_space(search_space: str | None) -> list[dict[str, Any]] | Pat
6364

6465

6566
def _optimize_n_evolutions(
67+
generator: Generator,
6668
input_path: str,
69+
dataset: Dataset,
6770
max_n_evolutions: int,
68-
evolutions: list,
69-
seed: int,
7071
split_train: str,
71-
async_mode: bool,
7272
batch_size: int,
73-
search_space: list[dict[str, Any]] | Path | str,
73+
search_space: Optional[str],
7474
) -> Dataset:
7575
emb_config = EmbedderConfig(batch_size=16, device="cuda")
76+
search_space = _choose_search_space(search_space)
7677

7778
best_result = 0
7879
best_n = 0
79-
dataset = load_dataset(input_path)
8080
merge_dataset = load_dataset(input_path)
8181

8282
for n in range(max_n_evolutions):
83-
generator = UtteranceEvolver(Generator(), evolutions, seed, async_mode)
8483
new_samples_dataset = generator.augment(
8584
dataset, split_name=split_train, n_evolutions=1, update_split=False, batch_size=batch_size
8685
)
@@ -101,32 +100,6 @@ def _optimize_n_evolutions(
101100
logger.info("# optimal n evolutions: %s", best_n)
102101
return dataset
103102

104-
105-
def _generate_fixed_evolutions(
106-
input_path: str,
107-
n_evolutions: int,
108-
evolutions: list,
109-
seed: int,
110-
split: str,
111-
async_mode: bool,
112-
batch_size: int,
113-
*args,
114-
**kwargs,
115-
) -> Dataset:
116-
dataset = load_dataset(input_path)
117-
n_before = len(dataset[split])
118-
119-
generator = UtteranceEvolver(Generator(), evolutions, seed, async_mode)
120-
new_samples = generator.augment(dataset, split_name=split, n_evolutions=n_evolutions, batch_size=batch_size)
121-
n_after = len(dataset[split])
122-
123-
logger.info("# samples before %s", n_before)
124-
logger.info("# samples generated %s", len(new_samples))
125-
logger.info("# samples after %s", n_after)
126-
127-
return dataset
128-
129-
130103
def _parse_args() -> Namespace:
131104
parser = ArgumentParser()
132105
parser.add_argument(
@@ -188,22 +161,30 @@ def main() -> None:
188161
logger.warning("No evolutions selected. Exiting.")
189162
return
190163

191-
search_space = _choose_search_space(args.search_space)
192164

193-
process_func = _generate_fixed_evolutions
165+
generator = UtteranceEvolver(Generator(), evolutions, args.seed, args.async_mode)
166+
dataset = load_dataset(args.input_path)
167+
194168
if args.decide_for_me:
195-
process_func = _optimize_n_evolutions
196-
197-
dataset = process_func(
198-
args.input_path,
199-
args.n_evolutions,
200-
evolutions,
201-
args.seed,
202-
args.split,
203-
args.async_mode,
204-
args.batch_size,
205-
search_space,
206-
)
169+
dataset = _optimize_n_evolutions(
170+
generator,
171+
args.input_path,
172+
dataset,
173+
args.n_evolutions,
174+
args.split,
175+
args.batch_size,
176+
args.search_space,
177+
)
178+
else:
179+
n_before = len(dataset[args.split])
180+
181+
new_samples = generator.augment(dataset, split_name=args.split, n_evolutions=args.n_evolutions, batch_size=args.batch_size)
182+
n_after = len(dataset[args.split])
183+
184+
logger.info("# samples before %s", n_before)
185+
logger.info("# samples generated %s", len(new_samples))
186+
logger.info("# samples after %s", n_after)
187+
207188
dataset.to_json(args.output_path)
208189

209190
if args.output_repo is not None:

0 commit comments

Comments
 (0)