
Commit 8fe33ab

feat: added search space
1 parent 6e06ffa commit 8fe33ab

File tree

  • autointent/generation/utterances/evolution

1 file changed: +33 -8 lines changed

autointent/generation/utterances/evolution/cli.py

Lines changed: 33 additions & 8 deletions
@@ -2,6 +2,8 @@
 
 import logging
 from argparse import ArgumentParser, Namespace
+from pathlib import Path
+from typing import Any
 
 from datasets import concatenate_datasets
 
@@ -21,7 +23,7 @@
     ReasoningEvolution,
 )
 
-# logging.basicConfig(level="INFO")
+logging.basicConfig(level="INFO")
 logger = logging.getLogger(__name__)
 
 SEARCH_SPACE = [
@@ -54,6 +56,12 @@
 ]
 
 
+def _choose_search_space(search_space: str | None) -> list[dict[str, Any]] | Path | str:
+    if search_space is None:
+        return SEARCH_SPACE
+    return search_space
+
+
 def _optimize_n_evolutions(
     input_path: str,
     max_n_evolutions: int,
@@ -62,14 +70,14 @@ def _optimize_n_evolutions(
     split_train: str,
     async_mode: bool,
     batch_size: int,
-) -> tuple[Dataset, int]:
+    search_space: list[dict[str, Any]] | Path | str,
+) -> Dataset:
     emb_config = EmbedderConfig(batch_size=16, device="cuda")
 
     best_result = 0
     best_n = 0
     dataset = load_dataset(input_path)
     merge_dataset = load_dataset(input_path)
-    k = 0.9
 
     for n in range(max_n_evolutions):
         generator = UtteranceEvolver(Generator(), evolutions, seed, async_mode)
@@ -78,12 +86,11 @@ def _optimize_n_evolutions(
         )
         merge_dataset[split_train] = concatenate_datasets([merge_dataset[split_train], new_samples_dataset])
 
-        pipeline_optimizer = Pipeline.from_search_space(SEARCH_SPACE)
+        pipeline_optimizer = Pipeline.from_search_space(search_space)
         pipeline_optimizer.set_config(emb_config)
         ctx = pipeline_optimizer.fit(merge_dataset)
         results = ctx.optimization_info.dump_evaluation_results()
-        decision_metric = results["metrics"]["decision"][0] - k
-        k -= 0.1
+        decision_metric = results["metrics"]["decision"][0]
 
         if decision_metric > best_result:
             best_result = decision_metric
@@ -96,7 +103,15 @@ def _optimize_n_evolutions(
 
 
 def _generate_fixed_evolutions(
-    input_path: str, n_evolutions: int, evolutions: list, seed: int, split: str, async_mode: bool, batch_size: int
+    input_path: str,
+    n_evolutions: int,
+    evolutions: list,
+    seed: int,
+    split: str,
+    async_mode: bool,
+    batch_size: int,
+    *args,
+    **kwargs,
 ) -> Dataset:
     dataset = load_dataset(input_path)
     n_before = len(dataset[split])
@@ -146,6 +161,7 @@ def _parse_args() -> Namespace:
     parser.add_argument("--async-mode", action="store_true", help="Enable asynchronous generation")
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--batch-size", type=int, default=4)
+    parser.add_argument("--search-space", type=str, default=None)
 
     return parser.parse_args()
 
@@ -172,12 +188,21 @@ def main() -> None:
         logger.warning("No evolutions selected. Exiting.")
         return
 
+    search_space = _choose_search_space(args.search_space)
+
     process_func = _generate_fixed_evolutions
     if args.decide_for_me:
         process_func = _optimize_n_evolutions
 
     dataset = process_func(
-        args.input_path, args.n_evolutions, evolutions, args.seed, args.split, args.async_mode, args.batch_size
+        args.input_path,
+        args.n_evolutions,
+        evolutions,
+        args.seed,
+        args.split,
+        args.async_mode,
+        args.batch_size,
+        search_space,
     )
     dataset.to_json(args.output_path)
 
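For context, a minimal self-contained sketch of the dispatch this commit introduces (illustration only: the empty SEARCH_SPACE stand-in, the renamed choose_search_space helper, and the "my_space.yaml" path are placeholders, not the real default or values from cli.py):

    # Mirrors the logic of _choose_search_space() added in this diff.
    from pathlib import Path
    from typing import Any

    SEARCH_SPACE: list[dict[str, Any]] = []  # stand-in for the module-level default in cli.py

    def choose_search_space(search_space: str | None) -> list[dict[str, Any]] | Path | str:
        if search_space is None:
            return SEARCH_SPACE  # no --search-space flag: fall back to the built-in default
        return search_space      # user-supplied value (e.g. a path) is forwarded unchanged

    assert choose_search_space(None) is SEARCH_SPACE
    assert choose_search_space("my_space.yaml") == "my_space.yaml"

Downstream, _optimize_n_evolutions now hands the chosen value to Pipeline.from_search_space(search_space) instead of the hard-coded SEARCH_SPACE, while _generate_fixed_evolutions gained *args/**kwargs so main() can pass search_space positionally to either process_func; the fixed-evolutions path simply ignores it.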
