Skip to content

Commit 9cffdbe

Browse files
committed
remove bruteforce sampling support
1 parent dafba96 commit 9cffdbe

File tree

3 files changed

+8
-12
lines changed

3 files changed

+8
-12
lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class Pipeline:
4444
def __init__(
4545
self,
4646
nodes: list[NodeOptimizer] | list[InferenceNode],
47-
sampler: SamplerType = "brute",
47+
sampler: SamplerType = "tpe",
4848
seed: int | None = 42,
4949
) -> None:
5050
"""Initialize the pipeline optimizer.

autointent/custom_types/_types.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,9 @@ class Split:
8989
INTENTS = "intents"
9090

9191

92-
SamplerType = Literal["brute", "tpe", "random"]
92+
SamplerType = Literal["tpe", "random"]
9393
"""Hyperparameter tuning strategies:
9494
95-
- `brute`: :py:class:`optuna.samplers.BruteForceSampler`
9695
- `tpe`: :py:class:`optuna.samplers.TPESampler`
9796
- `random`: :py:class:`optuna.samplers.RandomSampler`
9897
"""

autointent/nodes/_node_optimizer.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __init__(
6363
def fit(
6464
self,
6565
context: Context,
66-
sampler: SamplerType = "brute",
66+
sampler: SamplerType = "tpe",
6767
n_trials: int | None = None,
6868
timeout: float | None = None,
6969
n_jobs: int = 1,
@@ -82,15 +82,12 @@ def fit(
8282
"""
8383
self._logger.info("Starting %s node optimization...", self.node_info.node_type.value)
8484

85+
n_trials = n_trials or 10
86+
8587
if sampler == "tpe":
8688
sampler_instance = optuna.samplers.TPESampler(seed=context.seed)
87-
n_trials = n_trials or 10
88-
elif sampler == "brute":
89-
sampler_instance = optuna.samplers.BruteForceSampler(seed=context.seed) # type: ignore[assignment]
90-
n_trials = None
9189
elif sampler == "random":
9290
sampler_instance = optuna.samplers.RandomSampler(seed=context.seed) # type: ignore[assignment]
93-
n_trials = n_trials or 10
9491
else:
9592
assert_never(sampler)
9693

@@ -101,7 +98,7 @@ def fit(
10198
sampler=sampler_instance,
10299
n_trials=n_trials,
103100
)
104-
self._counter = max(self._counter, finished_trials)
101+
self._counter = finished_trials # zero if study is newly created
105102

106103
optuna.logging.set_verbosity(optuna.logging.WARNING)
107104
obj = partial(self.objective, search_space=self.modules_search_spaces, context=context)
@@ -364,8 +361,8 @@ def load_or_create_study(
364361
study_name: str,
365362
context: Context,
366363
sampler: optuna.samplers.BaseSampler,
364+
n_trials: int,
367365
direction: str = "maximize",
368-
n_trials: int = 10,
369366
) -> tuple[optuna.Study, int, int]:
370367
"""Load an existing study or create a new one if it doesn't exist.
371368
@@ -396,7 +393,7 @@ def load_or_create_study(
396393
# Find the highest trial number to continue counting
397394
finished_trials = max(t.number for t in study.trials) + 1
398395
# Calculate remaining trials if n_trials is specified
399-
remaining_trials = n_trials if n_trials is None else max(0, n_trials - len(study.trials))
396+
remaining_trials = max(0, n_trials - len(study.trials))
400397

401398
context.load_optimization_info()
402399
return study, finished_trials, remaining_trials # noqa: TRY300

0 commit comments

Comments
 (0)