Skip to content

Commit 8227550

Browse files
committed
adding moo_lower_bounds to optuna
1 parent 26d479c commit 8227550

File tree

1 file changed

+37
-3
lines changed

1 file changed

+37
-3
lines changed

deephyper_benchmark/search/_mpi_doptuna.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ def __call__(self, study: optuna.study.Study, trial: FrozenTrial) -> None:
6161
)
6262

6363

64+
# Constraints
65+
def constraints(trial):
66+
return trial.user_attrs["constraints"]
67+
68+
6469
# Supported samplers
6570
supported_samplers = ["TPE", "CMAES", "NSGAII", "DUMMY", "BOTORCH", "QMC"]
6671
supported_pruners = ["NOP", "SHA", "HB", "MED"]
@@ -82,6 +87,7 @@ class MPIDistributedOptuna(Search):
8287
storage (Union[str, optuna.storages.BaseStorage], optional): Database used by Optuna. Defaults to ``None``.
8388
checkpoint (bool, optional): If results should be checkpointed regularly to the ``log_dir``. Defaults to ``True``.
8489
comm (MPI.Comm, optional): The MPI communicator. Defaults to ``None``.
90+
moo_lower_bounds ([type], optional): [description]. Defaults to ``None``.
8591
8692
Raises:
8793
ValueError: _description_
@@ -106,6 +112,7 @@ def __init__(
106112
checkpoint: bool = True,
107113
n_initial_points: int = None,
108114
comm: MPI.Comm = None,
115+
moo_lower_bounds=None,
109116
**kwargs,
110117
):
111118
super().__init__(problem, evaluator, random_state, log_dir, verbose)
@@ -135,28 +142,46 @@ def __init__(
135142
2 * len(self._problem) if n_initial_points is None else n_initial_points
136143
)
137144

145+
# Constraints
146+
self._moo_lower_bounds = moo_lower_bounds
147+
self._constraints_func = None
148+
if moo_lower_bounds is not None:
149+
if len(moo_lower_bounds) == n_objectives:
150+
self._constraints_func = constraints
151+
else:
152+
raise ValueError(
153+
f"moo_lower_bounds should be of length {n_objectives} but is of length {len(moo_lower_bounds)}"
154+
)
155+
138156
# Setup the sampler
139157
if isinstance(sampler, optuna.samplers.BaseSampler):
140158
pass
141159
elif isinstance(sampler, str):
142160
sampler_seed = self._random_state.randint(2**31)
143161
if sampler == "TPE":
144162
sampler = optuna.samplers.TPESampler(
145-
n_startup_trials=self._n_initial_points, seed=sampler_seed
163+
n_startup_trials=self._n_initial_points,
164+
seed=sampler_seed,
165+
constraints_func=self._constraints_func,
146166
)
147167
elif sampler == "CMAES":
148168
sampler = optuna.samplers.CmaEsSampler(
149169
n_startup_trials=self._n_initial_points, seed=sampler_seed
150170
)
151171
elif sampler == "NSGAII":
152-
sampler = optuna.samplers.NSGAIISampler(seed=sampler_seed)
172+
sampler = optuna.samplers.NSGAIISampler(
173+
seed=sampler_seed,
174+
constraints_func=self._constraints_func,
175+
)
153176
elif sampler == "DUMMY":
154177
sampler = optuna.samplers.RandomSampler(seed=sampler_seed)
155178
elif sampler == "BOTORCH":
156179
from optuna.integration import BoTorchSampler
157180

158181
sampler = BoTorchSampler(
159-
n_startup_trials=self._n_initial_points, seed=sampler_seed
182+
n_startup_trials=self._n_initial_points,
183+
seed=sampler_seed,
184+
constraints_func=self._constraints_func,
160185
)
161186
elif sampler == "QMC":
162187
sampler = optuna.samplers.QMCSampler(seed=sampler_seed)
@@ -253,6 +278,15 @@ def objective_wrapper(trial):
253278
)
254279

255280
# TODO: optuna constraint
281+
if self._moo_lower_bounds is not None:
282+
# https://optuna.readthedocs.io/en/stable/faq.html#how-can-i-optimize-a-model-with-some-constraints
283+
# Constraints which are considered feasible if less than or equal to zero.
284+
constraints = []
285+
for i, lbi in enumerate(self._moo_lower_bounds):
286+
if lbi is not None:
287+
ci = -(output["objective"][i] - lbi) # <= 0
288+
constraints.append(ci)
289+
trial.set_user_attr("constraints", tuple(constraints))
256290

257291
data = {f"p:{k}": v for k, v in config.items()}
258292
if isinstance(output["objective"], list) or isinstance(

0 commit comments

Comments
 (0)