Skip to content

Commit 46b10f6

Browse files
committed
refactor
1 parent ed36291 commit 46b10f6

File tree

1 file changed

+9
-8
lines changed
  • autointent/generation/utterances/evolution

1 file changed

+9
-8
lines changed

autointent/generation/utterances/evolution/cli.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
import logging
44
from argparse import ArgumentParser, Namespace
55
from pathlib import Path
6-
from typing import Any, Optional
6+
from typing import Any
77

88
from datasets import concatenate_datasets
9-
from datasets import Dataset as HFDataset
109

1110
from autointent import Dataset, Pipeline, load_dataset
1211
from autointent.configs import EmbedderConfig
@@ -70,7 +69,7 @@ def _optimize_n_evolutions(
7069
max_n_evolutions: int,
7170
split_train: str,
7271
batch_size: int,
73-
search_space: Optional[str],
72+
search_space: str | None,
7473
) -> Dataset:
7574
emb_config = EmbedderConfig(batch_size=16, device="cuda")
7675
search_space = _choose_search_space(search_space)
@@ -100,6 +99,7 @@ def _optimize_n_evolutions(
10099
logger.info("# optimal n evolutions: %s", best_n)
101100
return dataset
102101

102+
103103
def _parse_args() -> Namespace:
104104
parser = ArgumentParser()
105105
parser.add_argument(
@@ -161,7 +161,6 @@ def main() -> None:
161161
logger.warning("No evolutions selected. Exiting.")
162162
return
163163

164-
165164
generator = UtteranceEvolver(Generator(), evolutions, args.seed, args.async_mode)
166165
dataset = load_dataset(args.input_path)
167166

@@ -177,14 +176,16 @@ def main() -> None:
177176
)
178177
else:
179178
n_before = len(dataset[args.split])
180-
181-
new_samples = generator.augment(dataset, split_name=args.split, n_evolutions=args.n_evolutions, batch_size=args.batch_size)
179+
180+
new_samples = generator.augment(
181+
dataset, split_name=args.split, n_evolutions=args.n_evolutions, batch_size=args.batch_size
182+
)
182183
n_after = len(dataset[args.split])
183-
184+
184185
logger.info("# samples before %s", n_before)
185186
logger.info("# samples generated %s", len(new_samples))
186187
logger.info("# samples after %s", n_after)
187-
188+
188189
dataset.to_json(args.output_path)
189190

190191
if args.output_repo is not None:

0 commit comments

Comments
 (0)