33import logging
44from argparse import ArgumentParser , Namespace
55from pathlib import Path
6- from typing import Any , Optional
6+ from typing import Any
77
88from datasets import concatenate_datasets
9- from datasets import Dataset as HFDataset
109
1110from autointent import Dataset , Pipeline , load_dataset
1211from autointent .configs import EmbedderConfig
@@ -70,7 +69,7 @@ def _optimize_n_evolutions(
7069 max_n_evolutions : int ,
7170 split_train : str ,
7271 batch_size : int ,
73- search_space : Optional [ str ] ,
72+ search_space : str | None ,
7473) -> Dataset :
7574 emb_config = EmbedderConfig (batch_size = 16 , device = "cuda" )
7675 search_space = _choose_search_space (search_space )
@@ -100,6 +99,7 @@ def _optimize_n_evolutions(
10099 logger .info ("# optimal n evolutions: %s" , best_n )
101100 return dataset
102101
102+
103103def _parse_args () -> Namespace :
104104 parser = ArgumentParser ()
105105 parser .add_argument (
@@ -161,7 +161,6 @@ def main() -> None:
161161 logger .warning ("No evolutions selected. Exiting." )
162162 return
163163
164-
165164 generator = UtteranceEvolver (Generator (), evolutions , args .seed , args .async_mode )
166165 dataset = load_dataset (args .input_path )
167166
@@ -177,14 +176,16 @@ def main() -> None:
177176 )
178177 else :
179178 n_before = len (dataset [args .split ])
180-
181- new_samples = generator .augment (dataset , split_name = args .split , n_evolutions = args .n_evolutions , batch_size = args .batch_size )
179+
180+ new_samples = generator .augment (
181+ dataset , split_name = args .split , n_evolutions = args .n_evolutions , batch_size = args .batch_size
182+ )
182183 n_after = len (dataset [args .split ])
183-
184+
184185 logger .info ("# samples before %s" , n_before )
185186 logger .info ("# samples generated %s" , len (new_samples ))
186187 logger .info ("# samples after %s" , n_after )
187-
188+
188189 dataset .to_json (args .output_path )
189190
190191 if args .output_repo is not None :