Skip to content

Commit dafba96

Browse files
committed
implement new logic
1 parent fe2b79e commit dafba96

File tree

1 file changed

+68
-56
lines changed

1 file changed

+68
-56
lines changed

autointent/nodes/_node_optimizer.py

Lines changed: 68 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import itertools as it
55
import json
66
import logging
7+
import os
78
from copy import deepcopy
89
from functools import partial
910
from pathlib import Path
@@ -59,59 +60,60 @@ def __init__(
5960
self.validate_search_space(search_space)
6061
self.modules_search_spaces = search_space
6162

62-
def fit(self, context: Context, sampler: SamplerType = "brute", n_jobs: int = 1) -> None:
63+
def fit(
64+
self,
65+
context: Context,
66+
sampler: SamplerType = "brute",
67+
n_trials: int | None = None,
68+
timeout: float | None = None,
69+
n_jobs: int = 1,
70+
) -> None:
6371
"""Performs the optimization process for the node.
6472
6573
Args:
6674
context: The optimization context containing relevant data.
6775
sampler: The sampling strategy used for optimization.
76+
n_trials: Number of optuna trials.
77+
timeout: Number of seconds for optimizing the whole node.
6878
n_jobs: The number of parallel jobs to run during optimization.
6979
7080
Raises:
7181
AssertionError: If an invalid sampler type is provided.
7282
"""
7383
self._logger.info("Starting %s node optimization...", self.node_info.node_type.value)
74-
for search_space in deepcopy(self.modules_search_spaces):
75-
self._counter: int = 0
76-
module_name = search_space.pop("module_name")
77-
n_trials = search_space.pop("n_trials", None)
78-
79-
if sampler == "tpe":
80-
sampler_instance = optuna.samplers.TPESampler(seed=context.seed)
81-
n_trials = n_trials or 10
82-
elif sampler == "brute":
83-
sampler_instance = optuna.samplers.BruteForceSampler(seed=context.seed) # type: ignore[assignment]
84-
n_trials = None
85-
elif sampler == "random":
86-
sampler_instance = optuna.samplers.RandomSampler(seed=context.seed) # type: ignore[assignment]
87-
n_trials = n_trials or 10
88-
else:
89-
assert_never(sampler)
90-
91-
if n_trials and (possible_combinations := self._n_possible_combinations(search_space)):
92-
n_trials = min(possible_combinations, n_trials)
9384

94-
study, finished_trials, n_trials = load_or_create_study(
95-
study_name=f"{self.node_info.node_type}_{module_name}",
96-
context=context,
97-
direction="maximize",
98-
sampler=sampler_instance,
99-
n_trials=n_trials,
100-
)
101-
self._counter = max(self._counter, finished_trials)
85+
if sampler == "tpe":
86+
sampler_instance = optuna.samplers.TPESampler(seed=context.seed)
87+
n_trials = n_trials or 10
88+
elif sampler == "brute":
89+
sampler_instance = optuna.samplers.BruteForceSampler(seed=context.seed) # type: ignore[assignment]
90+
n_trials = None
91+
elif sampler == "random":
92+
sampler_instance = optuna.samplers.RandomSampler(seed=context.seed) # type: ignore[assignment]
93+
n_trials = n_trials or 10
94+
else:
95+
assert_never(sampler)
96+
97+
study, finished_trials, n_trials = load_or_create_study(
98+
study_name=self.node_info.node_type,
99+
context=context,
100+
direction="maximize",
101+
sampler=sampler_instance,
102+
n_trials=n_trials,
103+
)
104+
self._counter = max(self._counter, finished_trials)
102105

103-
optuna.logging.set_verbosity(optuna.logging.WARNING)
104-
obj = partial(self.objective, module_name=module_name, search_space=search_space, context=context)
106+
optuna.logging.set_verbosity(optuna.logging.WARNING)
107+
obj = partial(self.objective, search_space=self.modules_search_spaces, context=context)
105108

106-
study.optimize(obj, n_trials=n_trials, n_jobs=n_jobs)
109+
study.optimize(obj, n_trials=n_trials, n_jobs=n_jobs, gc_after_trial=True, timeout=timeout)
107110

108111
self._logger.info("%s node optimization is finished!", self.node_info.node_type)
109112

110113
def objective(
111114
self,
112115
trial: Trial,
113-
module_name: str,
114-
search_space: dict[str, ParamSpaceInt | ParamSpaceFloat | list[Any]],
116+
search_space: list[dict[str, Any]],
115117
context: Context,
116118
) -> float:
117119
"""Defines the objective function for optimization.
@@ -125,13 +127,17 @@ def objective(
125127
Returns:
126128
The value of the target metric for the given trial.
127129
"""
128-
config = self.suggest(trial, search_space)
130+
module_name, module_hyperparams = self._suggest_module_and_hyperparams(trial, search_space)
129131

130-
self._logger.debug("Initializing %s module with config: %s", module_name, json.dumps(config))
131-
module = self.node_info.modules_available[module_name].from_context(context, **config)
132-
config.update(module.get_implicit_initialization_params())
132+
self._logger.debug("Initializing %s module with config: %s", module_name, json.dumps(module_hyperparams))
133+
module = self.node_info.modules_available[module_name].from_context(context, **module_hyperparams)
134+
module_hyperparams.update(module.get_implicit_initialization_params())
133135

134-
context.callback_handler.start_module(module_name=module.trial_name, num=self._counter, module_kwargs=config)
136+
context.callback_handler.start_module(
137+
module_name=module.trial_name,
138+
num=self._counter,
139+
module_kwargs=module_hyperparams,
140+
)
135141

136142
self._logger.debug("Scoring %s module...", module_name)
137143

@@ -148,7 +154,7 @@ def objective(
148154
context.optimization_info.log_module_optimization(
149155
node_type=self.node_info.node_type,
150156
module_name=module_name,
151-
module_params=config,
157+
module_params=module_hyperparams,
152158
metric_value=target_metric,
153159
metric_name=self.target_metric,
154160
metrics=quality_metrics,
@@ -166,30 +172,32 @@ def objective(
166172
self._counter += 1
167173
return target_metric
168174

169-
def suggest(self, trial: Trial, search_space: dict[str, Any | list[Any]]) -> dict[str, Any]:
170-
"""Suggests parameter values based on the search space.
171-
172-
Args:
173-
trial: The Optuna trial instance.
174-
search_space: A dictionary defining the parameter search space.
175-
176-
Returns:
177-
A dictionary containing the suggested parameter values.
178-
179-
Raises:
180-
TypeError: If an unsupported parameter search space type is encountered.
181-
"""
175+
def _suggest_module_and_hyperparams(
176+
self, trial: Trial, search_space: list[dict[str, Any]]
177+
) -> tuple[str, dict[str, Any]]:
178+
"""Sample a module name and its hyperparameters from the given search space."""
179+
n_modules = len(search_space)
180+
id_module_chosen = trial.suggest_categorical("module_idx", list(range(n_modules)))
181+
module_chosen = deepcopy(search_space[id_module_chosen])
182+
module_name = module_chosen.pop("module_name")
183+
module_config = self._suggest_hyperparams(trial, f"{module_name}_{id_module_chosen}", module_chosen)
184+
return module_name, module_config
185+
186+
def _suggest_hyperparams(
187+
self, trial: Trial, module_name: str, search_space: dict[str, Any | list[Any]]
188+
) -> dict[str, Any]:
182189
res: dict[str, Any] = {}
183190

184191
for param_name, param_space in search_space.items():
192+
name = f"{module_name}_{param_name}"
185193
if isinstance(param_space, list):
186-
res[param_name] = trial.suggest_categorical(param_name, choices=param_space)
194+
res[param_name] = trial.suggest_categorical(name, choices=param_space)
187195
elif self._parse_param_space(param_space, ParamSpaceInt):
188-
res[param_name] = trial.suggest_int(param_name, **param_space)
196+
res[param_name] = trial.suggest_int(name, **param_space)
189197
elif self._parse_param_space(param_space, ParamSpaceFloat):
190-
res[param_name] = trial.suggest_float(param_name, **param_space)
198+
res[param_name] = trial.suggest_float(name, **param_space)
191199
else:
192-
msg = f"Unsupported type of param search space: {param_space}"
200+
msg = f"Unsupported type of param search space {name}: {param_space}"
193201
raise TypeError(msg)
194202
return res
195203

@@ -294,6 +302,10 @@ def validate_nodes_with_dataset(self, dataset: Dataset, mode: SearchSpaceValidat
294302
def validate_search_space(self, search_space: list[dict[str, Any]]) -> None:
295303
"""Check if search space is configured correctly."""
296304
validated_search_space = SearchSpaceConfig(search_space).model_dump()
305+
306+
if not bool(int(os.getenv("AUTOINTENT_EXTRA_VALIDATION", "0"))):
307+
return
308+
297309
for module_search_space in validated_search_space:
298310
module_search_space_no_optuna, module_name = self._reformat_search_space(deepcopy(module_search_space))
299311

0 commit comments

Comments
 (0)