33import logging
44from argparse import ArgumentParser , Namespace
55from pathlib import Path
6- from typing import Any
6+ from typing import Any , Optional
77
88from datasets import concatenate_datasets
9+ from datasets import Dataset as HFDataset
910
1011from autointent import Dataset , Pipeline , load_dataset
1112from autointent .configs import EmbedderConfig
@@ -63,24 +64,22 @@ def _choose_search_space(search_space: str | None) -> list[dict[str, Any]] | Pat
6364
6465
6566def _optimize_n_evolutions (
67+ generator : Generator ,
6668 input_path : str ,
69+ dataset : Dataset ,
6770 max_n_evolutions : int ,
68- evolutions : list ,
69- seed : int ,
7071 split_train : str ,
71- async_mode : bool ,
7272 batch_size : int ,
73- search_space : list [ dict [ str , Any ]] | Path | str ,
73+ search_space : Optional [ str ] ,
7474) -> Dataset :
7575 emb_config = EmbedderConfig (batch_size = 16 , device = "cuda" )
76+ search_space = _choose_search_space (search_space )
7677
7778 best_result = 0
7879 best_n = 0
79- dataset = load_dataset (input_path )
8080 merge_dataset = load_dataset (input_path )
8181
8282 for n in range (max_n_evolutions ):
83- generator = UtteranceEvolver (Generator (), evolutions , seed , async_mode )
8483 new_samples_dataset = generator .augment (
8584 dataset , split_name = split_train , n_evolutions = 1 , update_split = False , batch_size = batch_size
8685 )
@@ -101,32 +100,6 @@ def _optimize_n_evolutions(
101100 logger .info ("# optimal n evolutions: %s" , best_n )
102101 return dataset
103102
104-
105- def _generate_fixed_evolutions (
106- input_path : str ,
107- n_evolutions : int ,
108- evolutions : list ,
109- seed : int ,
110- split : str ,
111- async_mode : bool ,
112- batch_size : int ,
113- * args ,
114- ** kwargs ,
115- ) -> Dataset :
116- dataset = load_dataset (input_path )
117- n_before = len (dataset [split ])
118-
119- generator = UtteranceEvolver (Generator (), evolutions , seed , async_mode )
120- new_samples = generator .augment (dataset , split_name = split , n_evolutions = n_evolutions , batch_size = batch_size )
121- n_after = len (dataset [split ])
122-
123- logger .info ("# samples before %s" , n_before )
124- logger .info ("# samples generated %s" , len (new_samples ))
125- logger .info ("# samples after %s" , n_after )
126-
127- return dataset
128-
129-
130103def _parse_args () -> Namespace :
131104 parser = ArgumentParser ()
132105 parser .add_argument (
@@ -188,22 +161,30 @@ def main() -> None:
188161 logger .warning ("No evolutions selected. Exiting." )
189162 return
190163
191- search_space = _choose_search_space (args .search_space )
192164
193- process_func = _generate_fixed_evolutions
165+ generator = UtteranceEvolver (Generator (), evolutions , args .seed , args .async_mode )
166+ dataset = load_dataset (args .input_path )
167+
194168 if args .decide_for_me :
195- process_func = _optimize_n_evolutions
196-
197- dataset = process_func (
198- args .input_path ,
199- args .n_evolutions ,
200- evolutions ,
201- args .seed ,
202- args .split ,
203- args .async_mode ,
204- args .batch_size ,
205- search_space ,
206- )
169+ dataset = _optimize_n_evolutions (
170+ generator ,
171+ args .input_path ,
172+ dataset ,
173+ args .n_evolutions ,
174+ args .split ,
175+ args .batch_size ,
176+ args .search_space ,
177+ )
178+ else :
179+ n_before = len (dataset [args .split ])
180+
181+ new_samples = generator .augment (dataset , split_name = args .split , n_evolutions = args .n_evolutions , batch_size = args .batch_size )
182+ n_after = len (dataset [args .split ])
183+
184+ logger .info ("# samples before %s" , n_before )
185+ logger .info ("# samples generated %s" , len (new_samples ))
186+ logger .info ("# samples after %s" , n_after )
187+
207188 dataset .to_json (args .output_path )
208189
209190 if args .output_repo is not None :