@@ -2,6 +2,8 @@
 
 import logging
 from argparse import ArgumentParser, Namespace
+from pathlib import Path
+from typing import Any
 
 from datasets import concatenate_datasets
 
@@ -21,7 +23,7 @@
     ReasoningEvolution,
 )
 
-# logging.basicConfig(level="INFO")
+logging.basicConfig(level="INFO")
 logger = logging.getLogger(__name__)
 
 SEARCH_SPACE = [
@@ -54,6 +56,13 @@
 ]
 
 
+def _choose_search_space(search_space: str | None) -> list[dict[str, Any]] | Path | str:
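+    """Return the user-supplied search space if given, else the module-level default."""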
+    if search_space is None:
+        return SEARCH_SPACE
+    return search_space
+
+
 def _optimize_n_evolutions(
     input_path: str,
     max_n_evolutions: int,
@@ -62,14 +70,15 @@ def _optimize_n_evolutions(
     split_train: str,
     async_mode: bool,
     batch_size: int,
-) -> tuple[Dataset, int]:
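+    # The search space may be an in-memory spec, a config file path, or a raw string.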
+    search_space: list[dict[str, Any]] | Path | str,
+) -> Dataset:
     emb_config = EmbedderConfig(batch_size=16, device="cuda")
 
     best_result = 0
     best_n = 0
     dataset = load_dataset(input_path)
     merge_dataset = load_dataset(input_path)
-    k = 0.9
 
     for n in range(max_n_evolutions):
         generator = UtteranceEvolver(Generator(), evolutions, seed, async_mode)
@@ -78,12 +86,12 @@ def _optimize_n_evolutions(
         )
         merge_dataset[split_train] = concatenate_datasets([merge_dataset[split_train], new_samples_dataset])
 
-        pipeline_optimizer = Pipeline.from_search_space(SEARCH_SPACE)
+        pipeline_optimizer = Pipeline.from_search_space(search_space)
         pipeline_optimizer.set_config(emb_config)
         ctx = pipeline_optimizer.fit(merge_dataset)
         results = ctx.optimization_info.dump_evaluation_results()
-        decision_metric = results["metrics"]["decision"][0] - k
-        k -= 0.1
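+        # Score each round by the raw decision metric; no decaying offset is applied.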
+        decision_metric = results["metrics"]["decision"][0]
 
         if decision_metric > best_result:
             best_result = decision_metric
@@ -96,7 +103,17 @@
 
 
 def _generate_fixed_evolutions(
-    input_path: str, n_evolutions: int, evolutions: list, seed: int, split: str, async_mode: bool, batch_size: int
+    input_path: str,
+    n_evolutions: int,
+    evolutions: list,
+    seed: int,
+    split: str,
+    async_mode: bool,
+    batch_size: int,
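+    # Extra arguments (e.g. the resolved search space) are accepted but unused,
+    # so both processing functions can be called with the same argument list.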
+    *args,
+    **kwargs,
 ) -> Dataset:
     dataset = load_dataset(input_path)
     n_before = len(dataset[split])
@@ -146,6 +161,8 @@ def _parse_args() -> Namespace:
     parser.add_argument("--async-mode", action="store_true", help="Enable asynchronous generation")
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--batch-size", type=int, default=4)
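+    # Optional custom search space; the module-level SEARCH_SPACE is used when omitted.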
+    parser.add_argument("--search-space", type=str, default=None)
 
     return parser.parse_args()
 
@@ -172,12 +188,22 @@ def main() -> None:
         logger.warning("No evolutions selected. Exiting.")
         return
 
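+    # Resolve the search space once; _generate_fixed_evolutions simply ignores it.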
+    search_space = _choose_search_space(args.search_space)
+
     process_func = _generate_fixed_evolutions
     if args.decide_for_me:
         process_func = _optimize_n_evolutions
 
     dataset = process_func(
-        args.input_path, args.n_evolutions, evolutions, args.seed, args.split, args.async_mode, args.batch_size
+        args.input_path,
+        args.n_evolutions,
+        evolutions,
+        args.seed,
+        args.split,
+        args.async_mode,
+        args.batch_size,
+        search_space,
     )
     dataset.to_json(args.output_path)
 